This commit is contained in:
Quentin Fuxa
2026-01-02 23:52:00 +01:00
parent a6a85431f6
commit ed503be140
8 changed files with 454 additions and 4 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 95 KiB

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 95 KiB

After

Width:  |  Height:  |  Size: 100 KiB

View File

@@ -266,6 +266,8 @@ def generate_scatter(results, system_info, output_path, n_samples, lang="en",
"mlx SS small": (-55, -5),
"voxtral mlx": (10, -14),
"qwen3 0.6B": (10, 8),
"qwen3-mlx 0.6B": (10, -14),
"qwen3-mlx 1.7B": (10, 8),
"fw LA large-v3": (8, -5),
"fw SS large-v3": (8, 5),
}

View File

@@ -51,6 +51,12 @@ try:
except (ImportError, Exception):
pass
try:
import mlx_qwen3_asr # noqa: F401
AVAILABLE_BACKENDS.append("qwen3-mlx")
except ImportError:
pass
BACKEND_CONFIG = {
"whisper": {"model_size": "tiny", "lan": "en"},
"voxtral-mlx": {"backend": "voxtral-mlx", "lan": "en"},
@@ -61,6 +67,7 @@ BACKEND_CONFIG = {
"lan": "en",
"custom_alignment_heads": "scripts/alignment_heads_qwen3_asr_1.7B.json",
},
"qwen3-mlx": {"backend": "qwen3-mlx", "lan": "en"},
}
# Voxtral backends flush all words at once with proportionally-distributed
@@ -70,7 +77,7 @@ BACKEND_CONFIG = {
VOXTRAL_BACKENDS = {"voxtral-mlx", "voxtral-hf"}
# Backends that use batch-flush and may have non-monotonic timestamps
BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul"}
BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul", "qwen3-mlx"}
def backend_kwargs(backend: str) -> dict:

View File

@@ -109,6 +109,16 @@ BACKENDS = [
"streaming": "chunk",
"devices": ["cuda", "mps", "cpu"],
},
{
"id": "qwen3-mlx",
"name": "Qwen3 MLX",
"module": "mlx_qwen3_asr",
"install": "pip install mlx-qwen3-asr",
"description": "Qwen3-ASR on Apple Silicon (MLX, native streaming)",
"platform": "darwin-arm64",
"streaming": "native",
"devices": ["mlx"],
},
{
"id": "openai-api",
"name": "OpenAI API",
@@ -193,6 +203,9 @@ MODEL_CATALOG = [
# Qwen3 ASR
{"name": "qwen3:1.7b", "family": "qwen3", "params": "1.7B", "disk": "3.6 GB", "languages": 12, "quality": "good", "speed": "fast"},
{"name": "qwen3:0.6b", "family": "qwen3", "params": "0.6B", "disk": "1.4 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
# Qwen3 MLX (native streaming on Apple Silicon)
{"name": "qwen3-mlx:1.7b", "family": "qwen3-mlx", "params": "1.7B", "disk": "1.8 GB", "languages": 12, "quality": "good", "speed": "fast"},
{"name": "qwen3-mlx:0.6b", "family": "qwen3-mlx", "params": "0.6B", "disk": "0.7 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
]
@@ -310,6 +323,9 @@ def _model_is_downloaded(model_entry: dict, downloaded: dict) -> bool:
elif family == "qwen3":
size = name.split(":")[1] if ":" in name else "1.7b"
return QWEN3_REPOS.get(size, "") in downloaded
elif family == "qwen3-mlx":
size = name.split(":")[1] if ":" in name else "1.7b"
return QWEN3_REPOS.get(size, "") in downloaded
return False
@@ -324,6 +340,8 @@ def _best_backend_for_model(model_entry: dict) -> str:
return "voxtral"
elif family == "qwen3":
return "qwen3"
elif family == "qwen3-mlx":
return "qwen3-mlx"
elif family == "whisper":
if is_apple and _module_available("mlx_whisper"):
return "mlx-whisper"
@@ -383,6 +401,8 @@ def cmd_models():
# Skip platform-incompatible models
if name == "voxtral-mlx" and not is_apple_silicon:
continue
if m["family"] == "qwen3-mlx" and not is_apple_silicon:
continue
is_dl = _model_is_downloaded(m, downloaded)
@@ -447,6 +467,18 @@ def _resolve_pull_target(spec: str):
targets.append(("voxtral-mlx", VOXTRAL_MLX_REPO, "Voxtral Mini (MLX)"))
return targets
# Handle qwen3-mlx (must check before generic qwen3)
if backend_part == "qwen3-mlx" or size_part.startswith("qwen3-mlx"):
qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b"
if qwen_size.startswith("qwen3"):
qwen_size = "1.7b" # default
repo = QWEN3_REPOS.get(qwen_size)
if not repo:
print(f" Unknown Qwen3 size: {qwen_size}. Available: {', '.join(QWEN3_REPOS.keys())}")
return []
targets.append(("qwen3-mlx", repo, f"Qwen3-ASR MLX {qwen_size}"))
return targets
# Handle qwen3
if backend_part == "qwen3" or size_part.startswith("qwen3"):
qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b"
@@ -503,7 +535,7 @@ def _resolve_pull_target(spec: str):
else:
print(f" Unknown model: {spec}")
print(f" Available sizes: {', '.join(WHISPER_SIZES)}")
print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b")
print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b, qwen3-mlx:1.7b, qwen3-mlx:0.6b")
return []
return targets
@@ -986,6 +1018,9 @@ def _resolve_run_spec(spec: str):
if spec == "voxtral-mlx":
return "voxtral-mlx", None
if spec == "qwen3-mlx":
return "qwen3-mlx", None
if spec in WHISPER_SIZES:
return None, spec
@@ -1231,6 +1266,12 @@ def _probe_backend_state(processor) -> dict:
elif hasattr(transcription, "_mlx_processor"):
info["backend_type"] = "voxtral-mlx"
# Qwen3 MLX specifics
elif hasattr(transcription, "_session") and hasattr(transcription, "_state"):
info["backend_type"] = "qwen3-mlx"
info["samples_fed"] = getattr(transcription, "_samples_fed", 0)
info["committed_words"] = getattr(transcription, "_n_committed_words", 0)
# SimulStreaming specifics
elif hasattr(transcription, "prev_output"):
info["backend_type"] = "simulstreaming"

View File

@@ -121,6 +121,11 @@ class TranscriptionEngine:
self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
logger.info("Using Voxtral HF Transformers streaming backend")
elif config.backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXASR
self.tokenizer = None
self.asr = Qwen3MLXASR(**transcription_common_params)
logger.info("Using Qwen3 MLX native backend")
elif config.backend == "qwen3-simul":
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR
self.tokenizer = None
@@ -230,6 +235,9 @@ def online_factory(args, asr, language=None):
if backend == "vllm-realtime":
from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor
return VLLMRealtimeOnlineProcessor(asr)
if backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor
return Qwen3MLXOnlineProcessor(asr)
if backend == "qwen3-simul":
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingOnlineProcessor
return Qwen3SimulStreamingOnlineProcessor(asr)

View File

@@ -147,8 +147,8 @@ def parse_args():
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-mlx' for Qwen3-ASR on Apple Silicon (MLX). Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
)
parser.add_argument(
"--no-vac",

View File

@@ -0,0 +1,392 @@
"""
MLX-accelerated Qwen3-ASR backend for WhisperLiveKit.
Provides ``Qwen3MLXASR`` (model holder) and ``Qwen3MLXOnlineProcessor``
(batch-based processor) that plug into WhisperLiveKit's audio processing
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
Uses the ``mlx-qwen3-asr`` package for fast Qwen3 inference on Apple Silicon.
The batch ``session.transcribe()`` API is called on the full accumulated audio
buffer, and LocalAgreement-style diffing (HypothesisBuffer) commits stable
words across consecutive inferences.
"""
import logging
import sys
import time
from typing import List, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
# Whisper language codes -> Qwen3 canonical language names
# (duplicated from qwen3_asr.py to avoid importing torch at module level)
WHISPER_TO_QWEN3_LANGUAGE = {
"zh": "Chinese", "en": "English", "yue": "Cantonese",
"ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish",
"pt": "Portuguese", "id": "Indonesian", "it": "Italian",
"ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese",
"ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay",
"nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
"pl": "Polish", "cs": "Czech", "fa": "Persian",
"el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian",
}
# Model size aliases -> HuggingFace model IDs
QWEN3_MLX_MODEL_MAPPING = {
"base": "Qwen/Qwen3-ASR-0.6B",
"tiny": "Qwen/Qwen3-ASR-0.6B",
"small": "Qwen/Qwen3-ASR-0.6B",
"large": "Qwen/Qwen3-ASR-1.7B",
"medium": "Qwen/Qwen3-ASR-1.7B",
"large-v3": "Qwen/Qwen3-ASR-1.7B",
"qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
"qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
"qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
"qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
"1.7b": "Qwen/Qwen3-ASR-1.7B",
"0.6b": "Qwen/Qwen3-ASR-0.6B",
}
# ---------------------------------------------------------------------------
# Model holder
# ---------------------------------------------------------------------------
class Qwen3MLXASR:
"""Lightweight model holder -- loads the mlx-qwen3-asr model once and
keeps it alive for the lifetime of the server."""
sep = ""
SAMPLING_RATE = 16_000
def __init__(self, logfile=sys.stderr, **kwargs):
import mlx.core as mx
import mlx_qwen3_asr
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
# Resolve model ID from size aliases or explicit path
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = QWEN3_MLX_MODEL_MAPPING.get(
(model_size or "base").lower(), "Qwen/Qwen3-ASR-0.6B"
)
t0 = time.time()
logger.info("Loading Qwen3 MLX model '%s' ...", model_path)
self.session = mlx_qwen3_asr.Session(model_path, dtype=mx.float16)
logger.info("Qwen3 MLX model loaded in %.2fs", time.time() - t0)
self.backend_choice = "qwen3-mlx"
self.tokenizer = None
def transcribe(self, audio):
pass # all work happens in the online processor
# ---------------------------------------------------------------------------
# Online processor
# ---------------------------------------------------------------------------
class Qwen3MLXOnlineProcessor:
"""Batch-based processor that accumulates audio and periodically calls
``session.transcribe()`` on the full buffer.
Uses LocalAgreement-style diffing (HypothesisBuffer) to commit stable
words across consecutive inferences, exactly like the PyTorch Qwen3
backend with ``OnlineASRProcessor``.
Lifecycle (called by ``AudioProcessor.transcription_processor``):
insert_audio_chunk(pcm, time) -> process_iter() -> get_buffer()
... repeat ...
start_silence() / end_silence()
finish()
"""
SAMPLING_RATE = 16_000
def __init__(self, asr: Qwen3MLXASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self._session = asr.session
lan = asr.original_language
self._language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, "English") if lan else None
# Audio accumulation
self.audio_buffer = np.array([], dtype=np.float32)
self._buffer_time_offset: float = 0.0 # absolute time of audio_buffer[0]
# Throttle: minimum new audio (in samples) before re-running inference
self._min_new_samples: int = int(1.0 * self.SAMPLING_RATE) # 1 second
self._samples_since_last_inference: int = 0
# Buffer trimming — keep buffer short for fast re-transcription.
# The model produces ~0.2x RTF, so 15s buffer = ~3s per call.
self._max_buffer_sec: float = 15.0
self._trim_sec: float = 10.0 # keep this many seconds after trimming
# HypothesisBuffer for LocalAgreement diffing
self._committed: List[ASRToken] = []
self._prev_tokens: List[ASRToken] = [] # previous hypothesis (buffer role)
self._last_committed_time: float = 0.0
# Global time tracking
self._global_time_offset: float = 0.0 # extra offset from silences
# -- audio ingestion --
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self.audio_buffer = np.append(self.audio_buffer, audio)
self._samples_since_last_inference += len(audio)
# -- batch transcription --
def _transcribe_buffer(self) -> List[ASRToken]:
"""Run batch transcription on the full audio buffer and return tokens."""
if len(self.audio_buffer) < 400: # too short for meaningful transcription
return []
t0 = time.time()
try:
result = self._session.transcribe(
self.audio_buffer,
language=self._language,
return_timestamps=True,
)
except Exception as e:
logger.warning("[qwen3-mlx] transcribe error: %s", e, exc_info=True)
return []
dur = time.time() - t0
audio_dur = len(self.audio_buffer) / self.SAMPLING_RATE
logger.debug(
"[qwen3-mlx] transcribed %.1fs audio in %.2fs (%.2fx RTF)",
audio_dur, dur, dur / max(audio_dur, 0.01),
)
text = (result.text or "").strip()
if not text:
return []
# Build tokens from segments (word-level timestamps)
tokens: List[ASRToken] = []
if result.segments:
for i, seg in enumerate(result.segments):
word = seg["text"]
start = self._buffer_time_offset + seg["start"]
end = self._buffer_time_offset + seg["end"]
label = word if i == 0 else " " + word
tokens.append(ASRToken(start=start, end=end, text=label))
else:
# Fallback: estimate timestamps from word count
words = text.split()
step = audio_dur / max(len(words), 1)
for i, w in enumerate(words):
t_start = self._buffer_time_offset + i * step
t_end = self._buffer_time_offset + (i + 1) * step
label = w if i == 0 else " " + w
tokens.append(ASRToken(start=t_start, end=t_end, text=label))
return tokens
def _local_agreement(self, new_tokens: List[ASRToken]) -> List[ASRToken]:
"""LocalAgreement diffing: commit the longest common prefix between
the previous hypothesis (``self._prev_tokens``) and the new tokens.
Before comparing, strips tokens that correspond to already-committed
audio (i.e., tokens whose start time is before ``_last_committed_time``).
Also deduplicates boundary tokens (ngram matching) to avoid re-committing
the tail of the previous committed output.
Returns the newly committed tokens.
"""
# Step 1: Only keep tokens that are roughly "new" (after last committed time)
fresh_tokens = [
t for t in new_tokens
if t.start > self._last_committed_time - 0.1
]
# Step 2: Remove duplicates at the boundary with committed tokens
# (like HypothesisBuffer.insert's ngram dedup)
if fresh_tokens and self._committed:
max_ngram = min(len(self._committed), len(fresh_tokens), 5)
for n in range(1, max_ngram + 1):
committed_ngram = " ".join(
t.text.strip() for t in self._committed[-n:]
)
fresh_ngram = " ".join(
t.text.strip() for t in fresh_tokens[:n]
)
if committed_ngram == fresh_ngram:
fresh_tokens = fresh_tokens[n:]
break
# Step 3: LocalAgreement -- longest common prefix between prev and fresh
committed: List[ASRToken] = []
prev = self._prev_tokens
i = 0
j = 0
while i < len(fresh_tokens) and j < len(prev):
if fresh_tokens[i].text.strip() == prev[j].text.strip():
# Agreement: commit this token (use the new token's timestamps)
committed.append(fresh_tokens[i])
i += 1
j += 1
else:
break
# The remaining fresh tokens become the new "previous hypothesis"
self._prev_tokens = fresh_tokens[i:] if i < len(fresh_tokens) else []
return committed
def _trim_buffer_if_needed(self):
"""Trim the audio buffer if it exceeds max_buffer_sec.
Keeps the last ``_trim_sec`` seconds of audio. Also adjusts
committed token tracking and buffer_time_offset.
"""
buffer_dur = len(self.audio_buffer) / self.SAMPLING_RATE
if buffer_dur <= self._max_buffer_sec:
return
keep_sec = self._trim_sec
keep_samples = int(keep_sec * self.SAMPLING_RATE)
cut_samples = len(self.audio_buffer) - keep_samples
if cut_samples <= 0:
return
cut_sec = cut_samples / self.SAMPLING_RATE
self.audio_buffer = self.audio_buffer[cut_samples:]
self._buffer_time_offset += cut_sec
# Remove committed tokens that are before the new buffer start
self._committed = [
t for t in self._committed if t.end > self._buffer_time_offset
]
logger.debug(
"[qwen3-mlx] trimmed buffer: cut %.1fs, new offset %.1f, buffer %.1fs",
cut_sec, self._buffer_time_offset, len(self.audio_buffer) / self.SAMPLING_RATE,
)
# -- interface methods --
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""Process the current audio buffer.
Throttles inference to at least 1s of new audio between calls.
Returns (newly_committed_tokens, audio_processed_upto_time).
"""
try:
# Throttle: skip if not enough new audio since last inference
if (not is_last
and self._samples_since_last_inference < self._min_new_samples):
return [], self.end
self._samples_since_last_inference = 0
# Trim buffer if too long
self._trim_buffer_if_needed()
# Run batch transcription
new_tokens = self._transcribe_buffer()
# LocalAgreement diffing
committed = self._local_agreement(new_tokens)
if committed:
self._committed.extend(committed)
self._last_committed_time = committed[-1].end
return committed, self.end
except Exception as e:
logger.warning("[qwen3-mlx] process_iter error: %s", e, exc_info=True)
return [], self.end
def get_buffer(self) -> Transcript:
"""Return the unconfirmed text (the tail of the last hypothesis
that was not committed by LocalAgreement)."""
if not self._prev_tokens:
return Transcript(start=None, end=None, text="")
text = "".join(t.text for t in self._prev_tokens)
start = self._prev_tokens[0].start
end = self._prev_tokens[-1].end
return Transcript(start=start, end=end, text=text)
def _flush_all(self) -> List[ASRToken]:
"""Force a final transcription and commit all remaining words."""
# Run one last transcription on the full buffer
self._samples_since_last_inference = self._min_new_samples # bypass throttle
new_tokens = self._transcribe_buffer()
# Commit everything: first the agreed prefix, then the remainder
committed = self._local_agreement(new_tokens)
# Also commit any remaining buffer tokens
remaining = self._prev_tokens
self._prev_tokens = []
all_new = committed + remaining
if all_new:
self._committed.extend(all_new)
self._last_committed_time = all_new[-1].end
return all_new
def _reset_for_new_utterance(self):
"""Reset buffers for a new utterance, preserving time continuity."""
new_offset = self._buffer_time_offset + len(self.audio_buffer) / self.SAMPLING_RATE
saved_end = self.end
self.audio_buffer = np.array([], dtype=np.float32)
self._buffer_time_offset = new_offset
self._samples_since_last_inference = 0
self._committed = []
self._prev_tokens = []
self.end = saved_end
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush pending words when silence starts.
Unlike other backends, does NOT reset the audio buffer — the model
produces better results re-transcribing the full accumulated audio.
Buffer trimming at 30s handles memory naturally.
"""
words = self._flush_all()
logger.info("[qwen3-mlx] start_silence: flushed %d words", len(words))
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
words = self._flush_all()
logger.info("[qwen3-mlx] finish: flushed %d words", len(words))
return words, self.end