4 Commits
main ... qwen3

Author SHA1 Message Date
Quentin Fuxa
a540a5fd10 fix simul-kv audio trim bug, add 1.7B v2 alignment heads 2026-03-15 20:45:00 +01:00
Quentin Fuxa
7b08ea74ab add H100 benchmark figures 2026-03-15 19:15:00 +01:00
Quentin Fuxa
b69eaf82be qwen3 simul+kv: optimized streaming with kv cache reuse 2026-03-15 18:30:00 +01:00
Quentin Fuxa
ed503be140 qwen 2026-01-02 23:52:00 +01:00
17 changed files with 8010 additions and 4 deletions

BIN
benchmark_bars_h100.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

BIN
benchmark_latency_h100.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 95 KiB

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 147 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 95 KiB

After

Width:  |  Height:  |  Size: 100 KiB

BIN
benchmark_scatter_h100.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -266,6 +266,8 @@ def generate_scatter(results, system_info, output_path, n_samples, lang="en",
"mlx SS small": (-55, -5),
"voxtral mlx": (10, -14),
"qwen3 0.6B": (10, 8),
"qwen3-mlx 0.6B": (10, -14),
"qwen3-mlx 1.7B": (10, 8),
"fw LA large-v3": (8, -5),
"fw SS large-v3": (8, 5),
}

View File

@@ -51,6 +51,12 @@ try:
except (ImportError, Exception):
pass
try:
import mlx_qwen3_asr # noqa: F401
AVAILABLE_BACKENDS.append("qwen3-mlx")
except ImportError:
pass
BACKEND_CONFIG = {
"whisper": {"model_size": "tiny", "lan": "en"},
"voxtral-mlx": {"backend": "voxtral-mlx", "lan": "en"},
@@ -61,6 +67,7 @@ BACKEND_CONFIG = {
"lan": "en",
"custom_alignment_heads": "scripts/alignment_heads_qwen3_asr_1.7B.json",
},
"qwen3-mlx": {"backend": "qwen3-mlx", "lan": "en"},
}
# Voxtral backends flush all words at once with proportionally-distributed
@@ -70,7 +77,7 @@ BACKEND_CONFIG = {
VOXTRAL_BACKENDS = {"voxtral-mlx", "voxtral-hf"}
# Backends that use batch-flush and may have non-monotonic timestamps
BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul"}
BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul", "qwen3-mlx"}
def backend_kwargs(backend: str) -> dict:

View File

@@ -0,0 +1,116 @@
"""
Bridge between WhisperLiveKit STT and IWSLT26 MT pipeline.
Converts streaming ASRToken output from SimulStreaming into the JSONL
format expected by the AlignAtt MT agent (iwslt26-sst).
Output format (one JSON per line):
{"text": "word or phrase", "emission_time": 1.234, "is_final": false, "speech_time": 1.0}
Where:
- text: the emitted word/phrase
- emission_time: wall-clock time when the word was emitted (for compute-aware eval)
- speech_time: timestamp in the audio (for compute-unaware eval)
- is_final: whether this is the last word of a segment/silence boundary
"""
import json
import time
from typing import List, TextIO
from whisperlivekit.timed_objects import ASRToken
class CascadeBridge:
"""Converts ASRToken stream to JSONL for the MT agent."""
def __init__(self, output_file: TextIO = None):
self.output_file = output_file
self.start_time = time.time()
self.entries: List[dict] = []
def emit_tokens(self, tokens: List[ASRToken], is_final: bool = False):
"""Emit a batch of tokens from the STT."""
wall_clock = time.time() - self.start_time
for i, token in enumerate(tokens):
entry = {
"text": token.text.strip(),
"emission_time": round(wall_clock, 3),
"speech_time": round(token.start, 3),
"is_final": is_final and (i == len(tokens) - 1),
}
self.entries.append(entry)
if self.output_file:
self.output_file.write(json.dumps(entry) + "\n")
self.output_file.flush()
def get_entries(self) -> List[dict]:
return self.entries
def get_text(self) -> str:
"""Get the full transcribed text."""
return " ".join(e["text"] for e in self.entries if e["text"])
def save(self, path: str):
"""Save all entries to a JSONL file."""
with open(path, "w") as f:
for entry in self.entries:
f.write(json.dumps(entry) + "\n")
def run_stt_to_jsonl(
audio_path: str,
output_path: str,
model_id: str = "Qwen/Qwen3-ASR-0.6B",
alignment_heads_path: str = None,
border_fraction: float = 0.20,
language: str = "en",
chunk_sec: float = 1.0,
):
"""Run STT on an audio file and save JSONL output for the MT agent.
This is the main entry point for the cascade: audio file → JSONL.
"""
import wave
import numpy as np
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR, Qwen3SimulKVOnlineProcessor
# Load audio
with wave.open(audio_path, 'r') as wf:
audio = np.frombuffer(
wf.readframes(wf.getnframes()), dtype=np.int16
).astype(np.float32) / 32768.0
# Initialize STT
asr = Qwen3SimulKVASR(
model_dir=model_id,
lan=language,
alignment_heads_path=alignment_heads_path,
border_fraction=border_fraction,
)
proc = Qwen3SimulKVOnlineProcessor(asr)
bridge = CascadeBridge()
# Stream audio in chunks
chunk_samples = int(chunk_sec * 16000)
offset = 0
stream_time = 0.0
while offset < len(audio):
chunk = audio[offset:offset + chunk_samples]
stream_time += len(chunk) / 16000
proc.insert_audio_chunk(chunk, stream_time)
words, _ = proc.process_iter(is_last=False)
if words:
bridge.emit_tokens(words, is_final=False)
offset += chunk_samples
# Final flush
final_words, _ = proc.finish()
if final_words:
bridge.emit_tokens(final_words, is_final=True)
# Save
bridge.save(output_path)
return bridge

View File

@@ -109,6 +109,16 @@ BACKENDS = [
"streaming": "chunk",
"devices": ["cuda", "mps", "cpu"],
},
{
"id": "qwen3-mlx",
"name": "Qwen3 MLX",
"module": "mlx_qwen3_asr",
"install": "pip install mlx-qwen3-asr",
"description": "Qwen3-ASR on Apple Silicon (MLX, native streaming)",
"platform": "darwin-arm64",
"streaming": "native",
"devices": ["mlx"],
},
{
"id": "openai-api",
"name": "OpenAI API",
@@ -193,6 +203,9 @@ MODEL_CATALOG = [
# Qwen3 ASR
{"name": "qwen3:1.7b", "family": "qwen3", "params": "1.7B", "disk": "3.6 GB", "languages": 12, "quality": "good", "speed": "fast"},
{"name": "qwen3:0.6b", "family": "qwen3", "params": "0.6B", "disk": "1.4 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
# Qwen3 MLX (native streaming on Apple Silicon)
{"name": "qwen3-mlx:1.7b", "family": "qwen3-mlx", "params": "1.7B", "disk": "1.8 GB", "languages": 12, "quality": "good", "speed": "fast"},
{"name": "qwen3-mlx:0.6b", "family": "qwen3-mlx", "params": "0.6B", "disk": "0.7 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
]
@@ -310,6 +323,9 @@ def _model_is_downloaded(model_entry: dict, downloaded: dict) -> bool:
elif family == "qwen3":
size = name.split(":")[1] if ":" in name else "1.7b"
return QWEN3_REPOS.get(size, "") in downloaded
elif family == "qwen3-mlx":
size = name.split(":")[1] if ":" in name else "1.7b"
return QWEN3_REPOS.get(size, "") in downloaded
return False
@@ -324,6 +340,8 @@ def _best_backend_for_model(model_entry: dict) -> str:
return "voxtral"
elif family == "qwen3":
return "qwen3"
elif family == "qwen3-mlx":
return "qwen3-mlx"
elif family == "whisper":
if is_apple and _module_available("mlx_whisper"):
return "mlx-whisper"
@@ -383,6 +401,8 @@ def cmd_models():
# Skip platform-incompatible models
if name == "voxtral-mlx" and not is_apple_silicon:
continue
if m["family"] == "qwen3-mlx" and not is_apple_silicon:
continue
is_dl = _model_is_downloaded(m, downloaded)
@@ -447,6 +467,18 @@ def _resolve_pull_target(spec: str):
targets.append(("voxtral-mlx", VOXTRAL_MLX_REPO, "Voxtral Mini (MLX)"))
return targets
# Handle qwen3-mlx (must check before generic qwen3)
if backend_part == "qwen3-mlx" or size_part.startswith("qwen3-mlx"):
qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b"
if qwen_size.startswith("qwen3"):
qwen_size = "1.7b" # default
repo = QWEN3_REPOS.get(qwen_size)
if not repo:
print(f" Unknown Qwen3 size: {qwen_size}. Available: {', '.join(QWEN3_REPOS.keys())}")
return []
targets.append(("qwen3-mlx", repo, f"Qwen3-ASR MLX {qwen_size}"))
return targets
# Handle qwen3
if backend_part == "qwen3" or size_part.startswith("qwen3"):
qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b"
@@ -503,7 +535,7 @@ def _resolve_pull_target(spec: str):
else:
print(f" Unknown model: {spec}")
print(f" Available sizes: {', '.join(WHISPER_SIZES)}")
print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b")
print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b, qwen3-mlx:1.7b, qwen3-mlx:0.6b")
return []
return targets
@@ -986,6 +1018,9 @@ def _resolve_run_spec(spec: str):
if spec == "voxtral-mlx":
return "voxtral-mlx", None
if spec == "qwen3-mlx":
return "qwen3-mlx", None
if spec in WHISPER_SIZES:
return None, spec
@@ -1231,6 +1266,12 @@ def _probe_backend_state(processor) -> dict:
elif hasattr(transcription, "_mlx_processor"):
info["backend_type"] = "voxtral-mlx"
# Qwen3 MLX specifics
elif hasattr(transcription, "_session") and hasattr(transcription, "_state"):
info["backend_type"] = "qwen3-mlx"
info["samples_fed"] = getattr(transcription, "_samples_fed", 0)
info["committed_words"] = getattr(transcription, "_n_committed_words", 0)
# SimulStreaming specifics
elif hasattr(transcription, "prev_output"):
info["backend_type"] = "simulstreaming"

View File

@@ -121,6 +121,20 @@ class TranscriptionEngine:
self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
logger.info("Using Voxtral HF Transformers streaming backend")
elif config.backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXASR
self.tokenizer = None
self.asr = Qwen3MLXASR(**transcription_common_params)
logger.info("Using Qwen3 MLX native backend")
elif config.backend == "qwen3-simul-kv":
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR
self.tokenizer = None
self.asr = Qwen3SimulKVASR(
**transcription_common_params,
alignment_heads_path=config.custom_alignment_heads,
border_fraction=getattr(config, 'border_fraction', 0.25),
)
logger.info("Using Qwen3-ASR backend with SimulStreaming+KV policy")
elif config.backend == "qwen3-simul":
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR
self.tokenizer = None
@@ -230,6 +244,12 @@ def online_factory(args, asr, language=None):
if backend == "vllm-realtime":
from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor
return VLLMRealtimeOnlineProcessor(asr)
if backend == "qwen3-simul-kv":
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVOnlineProcessor
return Qwen3SimulKVOnlineProcessor(asr)
if backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor
return Qwen3MLXOnlineProcessor(asr)
if backend == "qwen3-simul":
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingOnlineProcessor
return Qwen3SimulStreamingOnlineProcessor(asr)

View File

@@ -147,8 +147,8 @@ def parse_args():
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-mlx' for Qwen3-ASR on Apple Silicon (MLX). Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
)
parser.add_argument(
"--no-vac",

View File

@@ -0,0 +1,392 @@
"""
MLX-accelerated Qwen3-ASR backend for WhisperLiveKit.
Provides ``Qwen3MLXASR`` (model holder) and ``Qwen3MLXOnlineProcessor``
(batch-based processor) that plug into WhisperLiveKit's audio processing
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
Uses the ``mlx-qwen3-asr`` package for fast Qwen3 inference on Apple Silicon.
The batch ``session.transcribe()`` API is called on the full accumulated audio
buffer, and LocalAgreement-style diffing (HypothesisBuffer) commits stable
words across consecutive inferences.
"""
import logging
import sys
import time
from typing import List, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
# Whisper language codes -> Qwen3 canonical language names
# (duplicated from qwen3_asr.py to avoid importing torch at module level)
WHISPER_TO_QWEN3_LANGUAGE = {
"zh": "Chinese", "en": "English", "yue": "Cantonese",
"ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish",
"pt": "Portuguese", "id": "Indonesian", "it": "Italian",
"ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese",
"ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay",
"nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
"pl": "Polish", "cs": "Czech", "fa": "Persian",
"el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian",
}
# Model size aliases -> HuggingFace model IDs
QWEN3_MLX_MODEL_MAPPING = {
"base": "Qwen/Qwen3-ASR-0.6B",
"tiny": "Qwen/Qwen3-ASR-0.6B",
"small": "Qwen/Qwen3-ASR-0.6B",
"large": "Qwen/Qwen3-ASR-1.7B",
"medium": "Qwen/Qwen3-ASR-1.7B",
"large-v3": "Qwen/Qwen3-ASR-1.7B",
"qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
"qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
"qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
"qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
"1.7b": "Qwen/Qwen3-ASR-1.7B",
"0.6b": "Qwen/Qwen3-ASR-0.6B",
}
# ---------------------------------------------------------------------------
# Model holder
# ---------------------------------------------------------------------------
class Qwen3MLXASR:
"""Lightweight model holder -- loads the mlx-qwen3-asr model once and
keeps it alive for the lifetime of the server."""
sep = ""
SAMPLING_RATE = 16_000
def __init__(self, logfile=sys.stderr, **kwargs):
import mlx.core as mx
import mlx_qwen3_asr
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
# Resolve model ID from size aliases or explicit path
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = QWEN3_MLX_MODEL_MAPPING.get(
(model_size or "base").lower(), "Qwen/Qwen3-ASR-0.6B"
)
t0 = time.time()
logger.info("Loading Qwen3 MLX model '%s' ...", model_path)
self.session = mlx_qwen3_asr.Session(model_path, dtype=mx.float16)
logger.info("Qwen3 MLX model loaded in %.2fs", time.time() - t0)
self.backend_choice = "qwen3-mlx"
self.tokenizer = None
def transcribe(self, audio):
pass # all work happens in the online processor
# ---------------------------------------------------------------------------
# Online processor
# ---------------------------------------------------------------------------
class Qwen3MLXOnlineProcessor:
"""Batch-based processor that accumulates audio and periodically calls
``session.transcribe()`` on the full buffer.
Uses LocalAgreement-style diffing (HypothesisBuffer) to commit stable
words across consecutive inferences, exactly like the PyTorch Qwen3
backend with ``OnlineASRProcessor``.
Lifecycle (called by ``AudioProcessor.transcription_processor``):
insert_audio_chunk(pcm, time) -> process_iter() -> get_buffer()
... repeat ...
start_silence() / end_silence()
finish()
"""
SAMPLING_RATE = 16_000
def __init__(self, asr: Qwen3MLXASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self._session = asr.session
lan = asr.original_language
self._language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, "English") if lan else None
# Audio accumulation
self.audio_buffer = np.array([], dtype=np.float32)
self._buffer_time_offset: float = 0.0 # absolute time of audio_buffer[0]
# Throttle: minimum new audio (in samples) before re-running inference
self._min_new_samples: int = int(1.0 * self.SAMPLING_RATE) # 1 second
self._samples_since_last_inference: int = 0
# Buffer trimming — keep buffer short for fast re-transcription.
# The model produces ~0.2x RTF, so 15s buffer = ~3s per call.
self._max_buffer_sec: float = 15.0
self._trim_sec: float = 10.0 # keep this many seconds after trimming
# HypothesisBuffer for LocalAgreement diffing
self._committed: List[ASRToken] = []
self._prev_tokens: List[ASRToken] = [] # previous hypothesis (buffer role)
self._last_committed_time: float = 0.0
# Global time tracking
self._global_time_offset: float = 0.0 # extra offset from silences
# -- audio ingestion --
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self.audio_buffer = np.append(self.audio_buffer, audio)
self._samples_since_last_inference += len(audio)
# -- batch transcription --
def _transcribe_buffer(self) -> List[ASRToken]:
"""Run batch transcription on the full audio buffer and return tokens."""
if len(self.audio_buffer) < 400: # too short for meaningful transcription
return []
t0 = time.time()
try:
result = self._session.transcribe(
self.audio_buffer,
language=self._language,
return_timestamps=True,
)
except Exception as e:
logger.warning("[qwen3-mlx] transcribe error: %s", e, exc_info=True)
return []
dur = time.time() - t0
audio_dur = len(self.audio_buffer) / self.SAMPLING_RATE
logger.debug(
"[qwen3-mlx] transcribed %.1fs audio in %.2fs (%.2fx RTF)",
audio_dur, dur, dur / max(audio_dur, 0.01),
)
text = (result.text or "").strip()
if not text:
return []
# Build tokens from segments (word-level timestamps)
tokens: List[ASRToken] = []
if result.segments:
for i, seg in enumerate(result.segments):
word = seg["text"]
start = self._buffer_time_offset + seg["start"]
end = self._buffer_time_offset + seg["end"]
label = word if i == 0 else " " + word
tokens.append(ASRToken(start=start, end=end, text=label))
else:
# Fallback: estimate timestamps from word count
words = text.split()
step = audio_dur / max(len(words), 1)
for i, w in enumerate(words):
t_start = self._buffer_time_offset + i * step
t_end = self._buffer_time_offset + (i + 1) * step
label = w if i == 0 else " " + w
tokens.append(ASRToken(start=t_start, end=t_end, text=label))
return tokens
def _local_agreement(self, new_tokens: List[ASRToken]) -> List[ASRToken]:
"""LocalAgreement diffing: commit the longest common prefix between
the previous hypothesis (``self._prev_tokens``) and the new tokens.
Before comparing, strips tokens that correspond to already-committed
audio (i.e., tokens whose start time is before ``_last_committed_time``).
Also deduplicates boundary tokens (ngram matching) to avoid re-committing
the tail of the previous committed output.
Returns the newly committed tokens.
"""
# Step 1: Only keep tokens that are roughly "new" (after last committed time)
fresh_tokens = [
t for t in new_tokens
if t.start > self._last_committed_time - 0.1
]
# Step 2: Remove duplicates at the boundary with committed tokens
# (like HypothesisBuffer.insert's ngram dedup)
if fresh_tokens and self._committed:
max_ngram = min(len(self._committed), len(fresh_tokens), 5)
for n in range(1, max_ngram + 1):
committed_ngram = " ".join(
t.text.strip() for t in self._committed[-n:]
)
fresh_ngram = " ".join(
t.text.strip() for t in fresh_tokens[:n]
)
if committed_ngram == fresh_ngram:
fresh_tokens = fresh_tokens[n:]
break
# Step 3: LocalAgreement -- longest common prefix between prev and fresh
committed: List[ASRToken] = []
prev = self._prev_tokens
i = 0
j = 0
while i < len(fresh_tokens) and j < len(prev):
if fresh_tokens[i].text.strip() == prev[j].text.strip():
# Agreement: commit this token (use the new token's timestamps)
committed.append(fresh_tokens[i])
i += 1
j += 1
else:
break
# The remaining fresh tokens become the new "previous hypothesis"
self._prev_tokens = fresh_tokens[i:] if i < len(fresh_tokens) else []
return committed
def _trim_buffer_if_needed(self):
"""Trim the audio buffer if it exceeds max_buffer_sec.
Keeps the last ``_trim_sec`` seconds of audio. Also adjusts
committed token tracking and buffer_time_offset.
"""
buffer_dur = len(self.audio_buffer) / self.SAMPLING_RATE
if buffer_dur <= self._max_buffer_sec:
return
keep_sec = self._trim_sec
keep_samples = int(keep_sec * self.SAMPLING_RATE)
cut_samples = len(self.audio_buffer) - keep_samples
if cut_samples <= 0:
return
cut_sec = cut_samples / self.SAMPLING_RATE
self.audio_buffer = self.audio_buffer[cut_samples:]
self._buffer_time_offset += cut_sec
# Remove committed tokens that are before the new buffer start
self._committed = [
t for t in self._committed if t.end > self._buffer_time_offset
]
logger.debug(
"[qwen3-mlx] trimmed buffer: cut %.1fs, new offset %.1f, buffer %.1fs",
cut_sec, self._buffer_time_offset, len(self.audio_buffer) / self.SAMPLING_RATE,
)
# -- interface methods --
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""Process the current audio buffer.
Throttles inference to at least 1s of new audio between calls.
Returns (newly_committed_tokens, audio_processed_upto_time).
"""
try:
# Throttle: skip if not enough new audio since last inference
if (not is_last
and self._samples_since_last_inference < self._min_new_samples):
return [], self.end
self._samples_since_last_inference = 0
# Trim buffer if too long
self._trim_buffer_if_needed()
# Run batch transcription
new_tokens = self._transcribe_buffer()
# LocalAgreement diffing
committed = self._local_agreement(new_tokens)
if committed:
self._committed.extend(committed)
self._last_committed_time = committed[-1].end
return committed, self.end
except Exception as e:
logger.warning("[qwen3-mlx] process_iter error: %s", e, exc_info=True)
return [], self.end
def get_buffer(self) -> Transcript:
"""Return the unconfirmed text (the tail of the last hypothesis
that was not committed by LocalAgreement)."""
if not self._prev_tokens:
return Transcript(start=None, end=None, text="")
text = "".join(t.text for t in self._prev_tokens)
start = self._prev_tokens[0].start
end = self._prev_tokens[-1].end
return Transcript(start=start, end=end, text=text)
def _flush_all(self) -> List[ASRToken]:
"""Force a final transcription and commit all remaining words."""
# Run one last transcription on the full buffer
self._samples_since_last_inference = self._min_new_samples # bypass throttle
new_tokens = self._transcribe_buffer()
# Commit everything: first the agreed prefix, then the remainder
committed = self._local_agreement(new_tokens)
# Also commit any remaining buffer tokens
remaining = self._prev_tokens
self._prev_tokens = []
all_new = committed + remaining
if all_new:
self._committed.extend(all_new)
self._last_committed_time = all_new[-1].end
return all_new
def _reset_for_new_utterance(self):
"""Reset buffers for a new utterance, preserving time continuity."""
new_offset = self._buffer_time_offset + len(self.audio_buffer) / self.SAMPLING_RATE
saved_end = self.end
self.audio_buffer = np.array([], dtype=np.float32)
self._buffer_time_offset = new_offset
self._samples_since_last_inference = 0
self._committed = []
self._prev_tokens = []
self.end = saved_end
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush pending words when silence starts.
Unlike other backends, does NOT reset the audio buffer — the model
produces better results re-transcribing the full accumulated audio.
Buffer trimming at 30s handles memory naturally.
"""
words = self._flush_all()
logger.info("[qwen3-mlx] start_silence: flushed %d words", len(words))
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
words = self._flush_all()
logger.info("[qwen3-mlx] finish: flushed %d words", len(words))
return words, self.end

View File

@@ -0,0 +1,790 @@
"""
Qwen3-ASR SimulStreaming with KV cache reuse.
This is an optimized version of qwen3_simul.py that reuses the KV cache
across inference calls, avoiding redundant prefill of prompt + old audio.
Architecture:
1. First call: full prefill (prompt + audio tokens), greedy decode with
alignment-head stopping, save KV cache + generated tokens
2. Subsequent calls: invalidate KV for old audio suffix, prefill only
new audio tokens, continue decoding from saved state
3. Audio encoder caching: reuse embeddings for stable attention windows
This gives ~3-5x speedup over the original generate()-based approach.
"""
import json
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from transformers import DynamicCache
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000
@dataclass
class Qwen3SimulKVConfig:
"""Configuration for Qwen3 SimulStreaming with KV cache."""
model_id: str = "Qwen/Qwen3-ASR-1.7B"
alignment_heads_path: Optional[str] = None
language: str = "auto"
border_fraction: float = 0.20
rewind_fraction: float = 0.12
audio_min_len: float = 0.5
audio_max_len: float = 30.0
max_context_tokens: int = 20
init_prompt: Optional[str] = None
max_alignment_heads: int = 10
@dataclass
class _AudioEmbedCache:
"""Cache for audio encoder outputs."""
encoded_samples: int = 0
embeddings: Optional[torch.Tensor] = None
encoded_mel_frames: int = 0
stable_tokens: int = 0
def reset(self):
self.encoded_samples = 0
self.embeddings = None
self.encoded_mel_frames = 0
self.stable_tokens = 0
@dataclass
class Qwen3SimulKVState:
"""Per-session mutable state with KV cache."""
# Audio
audio_buffer: np.ndarray = field(
default_factory=lambda: np.array([], dtype=np.float32)
)
cumulative_time_offset: float = 0.0
global_time_offset: float = 0.0
speaker: int = -1
# KV cache state
kv_cache: Optional[DynamicCache] = None
kv_seq_len: int = 0 # sequence length when KV was saved
prompt_token_count: int = 0 # tokens before audio (system prompt etc)
audio_token_count: int = 0 # audio tokens in the cached KV
generated_token_ids: List[int] = field(default_factory=list)
# Alignment tracking
last_attend_frame: int = -15
committed_text: str = ""
committed_word_count: int = 0
committed_token_ids: List[int] = field(default_factory=list)
# Tracking
first_timestamp: Optional[float] = None
detected_language: Optional[str] = None
last_infer_samples: int = 0
# Audio embedding cache
audio_cache: _AudioEmbedCache = field(default_factory=_AudioEmbedCache)
def reset_kv(self):
"""Reset KV cache (e.g., when audio is trimmed from front)."""
self.kv_cache = None
self.kv_seq_len = 0
self.prompt_token_count = 0
self.audio_token_count = 0
self.generated_token_ids = []
# Reset alignment tracking — old frame references are invalid
# after audio is trimmed from the front
self.last_attend_frame = -15
class Qwen3SimulKVASR:
"""
Shared backend for Qwen3-ASR SimulStreaming with KV cache reuse.
"""
sep = ""
def __init__(
self,
model_size: str = None,
model_dir: str = None,
lan: str = "auto",
alignment_heads_path: Optional[str] = None,
border_fraction: float = 0.15,
min_chunk_size: float = 0.1,
warmup_file: Optional[str] = None,
model_cache_dir: Optional[str] = None,
model_path: Optional[str] = None,
lora_path: Optional[str] = None,
direct_english_translation: bool = False,
**kwargs,
):
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.warmup_file = warmup_file
self.cfg = Qwen3SimulKVConfig(
language=lan,
alignment_heads_path=alignment_heads_path,
border_fraction=border_fraction,
)
self._load_model(model_size, model_dir, model_cache_dir, model_path)
self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
# Pre-compute heads by layer for efficient hook installation
self.heads_by_layer = {}
for layer_idx, head_idx in self.alignment_heads:
self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)
if warmup_file:
from whisperlivekit.warmup import load_file
audio = load_file(warmup_file)
if audio is not None:
self._warmup(audio)
def _load_model(self, model_size, model_dir, model_cache_dir, model_path):
from whisperlivekit.qwen3_asr import QWEN3_MODEL_MAPPING, _patch_transformers_compat
_patch_transformers_compat()
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
if model_dir:
model_id = model_dir
elif model_path:
model_id = model_path
elif model_size:
model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
else:
model_id = "Qwen/Qwen3-ASR-1.7B"
if torch.cuda.is_available():
dtype, device = torch.bfloat16, "cuda:0"
else:
dtype, device = torch.float32, "cpu"
logger.info("Loading Qwen3-ASR for SimulStreaming+KV: %s", model_id)
self.model = AutoModel.from_pretrained(model_id, dtype=dtype, device_map=device)
self.model.eval()
self.processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True)
thinker = self.model.thinker
text_config = thinker.config.text_config
self.num_layers = text_config.num_hidden_layers
self.num_heads = text_config.num_attention_heads
self.num_kv_heads = text_config.num_key_value_heads
self.audio_token_id = thinker.config.audio_token_id
self.device = next(self.model.parameters()).device
self.dtype = next(self.model.parameters()).dtype
self.asr_text_token_id = self.processor.tokenizer.convert_tokens_to_ids("<asr_text>")
# EOS tokens
self.eos_ids = {151645, 151643}
if self.processor.tokenizer.eos_token_id is not None:
self.eos_ids.add(self.processor.tokenizer.eos_token_id)
logger.info(
"Qwen3-ASR loaded: %d layers x %d heads, device=%s",
self.num_layers, self.num_heads, self.device,
)
def _load_alignment_heads(self, path):
max_heads = self.cfg.max_alignment_heads
if path and Path(path).exists():
with open(path) as f:
data = json.load(f)
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
heads = all_heads[:max_heads]
logger.info("Loaded top %d alignment heads from %s", len(heads), path)
return heads
default_heads = []
start_layer = self.num_layers * 3 // 4
for layer in range(start_layer, self.num_layers):
for head in range(self.num_heads):
default_heads.append((layer, head))
logger.warning("No alignment heads file. Using %d default heads.", len(default_heads))
return default_heads[:max_heads]
def _warmup(self, audio):
try:
audio = audio[:SAMPLE_RATE * 2]
msgs = [{"role": "system", "content": ""}, {"role": "user", "content": [{"type": "audio", "audio": ""}]}]
text_prompt = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
inputs = self.processor(text=[text_prompt], audio=[audio], return_tensors="pt", padding=True)
inputs = inputs.to(self.device).to(self.dtype)
with torch.inference_mode():
self.model.thinker.generate(**inputs, max_new_tokens=5, do_sample=False)
logger.info("Warmup complete")
except Exception as e:
logger.warning("Warmup failed: %s", e)
def transcribe(self, audio):
pass
class Qwen3SimulKVOnlineProcessor:
"""
Per-session online processor with KV cache reuse.
Key optimization: instead of calling generate() each time (which does
full prefill), we maintain a DynamicCache and do incremental prefill
+ manual greedy decoding with alignment head hooks.
"""
SAMPLING_RATE = 16000
MIN_DURATION_REAL_SILENCE = 5
def __init__(self, asr: Qwen3SimulKVASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer: List[ASRToken] = []
self.state = Qwen3SimulKVState()
self._build_prompt_template()
def _build_prompt_template(self):
from whisperlivekit.qwen3_asr import WHISPER_TO_QWEN3_LANGUAGE
msgs = [
{"role": "system", "content": ""},
{"role": "user", "content": [{"type": "audio", "audio": ""}]},
]
self._base_prompt = self.asr.processor.apply_chat_template(
msgs, add_generation_prompt=True, tokenize=False,
)
lan = self.asr.cfg.language
if lan and lan != "auto":
lang_name = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
self._base_prompt += f"language {lang_name}<asr_text>"
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self.state.audio_buffer = np.append(self.state.audio_buffer, audio)
max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
if len(self.state.audio_buffer) > max_samples:
trim = len(self.state.audio_buffer) - max_samples
self.state.audio_buffer = self.state.audio_buffer[trim:]
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
self.state.audio_cache.reset()
self.state.reset_kv() # Must invalidate KV when audio is trimmed
def start_silence(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end
def end_silence(self, silence_duration: float, offset: float):
self.end += silence_duration
long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
if not long_silence:
gap_len = int(self.SAMPLING_RATE * silence_duration)
if gap_len > 0:
self.state.audio_buffer = np.append(
self.state.audio_buffer, np.zeros(gap_len, dtype=np.float32),
)
else:
self.state = Qwen3SimulKVState()
self.state.global_time_offset = silence_duration + offset
def new_speaker(self, change_speaker: ChangeSpeaker):
self.process_iter(is_last=True)
self.state = Qwen3SimulKVState()
self.state.speaker = change_speaker.speaker
self.state.global_time_offset = change_speaker.start
def get_buffer(self) -> Transcript:
return Transcript.from_tokens(tokens=self.buffer, sep='')
def _encode_audio(self) -> Tuple[torch.Tensor, int]:
"""Encode full audio buffer, with caching for stable windows."""
asr = self.asr
state = self.state
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
_get_feat_extract_output_lengths,
)
feat_out = asr.processor.feature_extractor(
[state.audio_buffer], sampling_rate=16000,
padding=True, truncation=False,
return_attention_mask=True, return_tensors="pt",
)
input_features = feat_out["input_features"].to(asr.device).to(asr.dtype)
feature_attention_mask = feat_out["attention_mask"].to(asr.device)
total_mel_frames = feature_attention_mask.sum().item()
total_audio_tokens = _get_feat_extract_output_lengths(
torch.tensor(total_mel_frames),
).item()
cache = state.audio_cache
audio_cfg = asr.model.thinker.audio_tower.config
n_window_infer = getattr(audio_cfg, "n_window_infer", 400)
n_complete_windows = total_mel_frames // n_window_infer
if n_complete_windows <= 0 or cache.embeddings is None:
# Full encode
audio_embeds = asr.model.thinker.get_audio_features(
input_features, feature_attention_mask=feature_attention_mask,
)
if audio_embeds.dim() == 3:
audio_embeds = audio_embeds[0]
stable_mel = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
stable_tokens = _get_feat_extract_output_lengths(
torch.tensor(stable_mel),
).item() if stable_mel > 0 else 0
else:
stable_mel = n_complete_windows * n_window_infer
stable_tokens = _get_feat_extract_output_lengths(
torch.tensor(stable_mel),
).item()
if cache.stable_tokens > 0 and cache.stable_tokens <= stable_tokens:
cached_prefix = cache.embeddings[:stable_tokens] if cache.embeddings.dim() == 2 else cache.embeddings[0, :stable_tokens]
tail_features = input_features[:, :, stable_mel:]
tail_mel_frames = total_mel_frames - stable_mel
if tail_mel_frames > 0:
tail_mask = torch.ones(
(1, tail_features.shape[2]),
dtype=feature_attention_mask.dtype,
device=feature_attention_mask.device,
)
tail_embeds = asr.model.thinker.get_audio_features(
tail_features, feature_attention_mask=tail_mask,
)
if tail_embeds.dim() == 3:
tail_embeds = tail_embeds[0]
audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0)
else:
audio_embeds = cached_prefix
else:
audio_embeds = asr.model.thinker.get_audio_features(
input_features, feature_attention_mask=feature_attention_mask,
)
if audio_embeds.dim() == 3:
audio_embeds = audio_embeds[0]
# Update cache
cache.embeddings = audio_embeds if audio_embeds.dim() == 2 else audio_embeds[0]
cache.encoded_samples = len(state.audio_buffer)
cache.encoded_mel_frames = total_mel_frames
stable_mel_final = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
cache.stable_tokens = _get_feat_extract_output_lengths(
torch.tensor(stable_mel_final),
).item() if stable_mel_final > 0 else 0
return audio_embeds, total_audio_tokens
def _build_full_inputs(self, audio_embeds: torch.Tensor) -> dict:
"""Build full input embeddings from prompt + audio embeddings + context."""
asr = self.asr
state = self.state
thinker = asr.model.thinker
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
_get_feat_extract_output_lengths,
)
n_audio_tokens = audio_embeds.shape[0]
prompt_with_placeholders = asr.processor.replace_multimodal_special_tokens(
[self._base_prompt], iter([n_audio_tokens]),
)[0]
text_ids = asr.processor.tokenizer(
[prompt_with_placeholders], return_tensors="pt", padding=True,
)
input_ids = text_ids["input_ids"].to(asr.device)
attention_mask = text_ids.get("attention_mask")
if attention_mask is not None:
attention_mask = attention_mask.to(asr.device)
# Append committed context tokens
if state.committed_token_ids:
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
ctx_ids = torch.tensor([ctx], dtype=input_ids.dtype, device=input_ids.device)
input_ids = torch.cat([input_ids, ctx_ids], dim=1)
if attention_mask is not None:
ctx_mask = torch.ones_like(ctx_ids)
attention_mask = torch.cat([attention_mask, ctx_mask], dim=1)
# Build inputs_embeds
inputs_embeds = thinker.get_input_embeddings()(input_ids)
audio_mask = (input_ids == asr.audio_token_id)
n_placeholders = audio_mask.sum().item()
if n_placeholders != n_audio_tokens:
logger.warning("Audio token mismatch: %d vs %d", n_placeholders, n_audio_tokens)
return None
audio_embeds_cast = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
expand_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(expand_mask, audio_embeds_cast)
# Find audio token range
audio_positions = audio_mask[0].nonzero(as_tuple=True)[0]
audio_start = audio_positions[0].item()
audio_end = audio_positions[-1].item() + 1
return {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"audio_start": audio_start,
"audio_end": audio_end,
"n_audio_tokens": n_audio_tokens,
}
@torch.inference_mode()
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
if audio_duration < self.asr.cfg.audio_min_len:
return [], self.end
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
min_new_seconds = 1.0
if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
return [], self.end
self.state.last_infer_samples = len(self.state.audio_buffer)
try:
timestamped_words = self._infer(is_last)
except Exception as e:
logger.exception("Inference error: %s", e)
self.state.reset_kv()
return [], self.end
if not timestamped_words:
return [], self.end
self.buffer = []
return timestamped_words, self.end
def _infer(self, is_last: bool) -> List[ASRToken]:
"""Run inference with KV cache reuse and alignment-head stopping."""
asr = self.asr
state = self.state
thinker = asr.model.thinker
# Step 1: Encode audio (with caching)
audio_embeds, n_audio_tokens_total = self._encode_audio()
# Step 2: Build full inputs
full_inputs = self._build_full_inputs(audio_embeds)
if full_inputs is None:
state.reset_kv()
return []
input_ids = full_inputs["input_ids"]
inputs_embeds = full_inputs["inputs_embeds"]
attention_mask = full_inputs["attention_mask"]
audio_start = full_inputs["audio_start"]
audio_end = full_inputs["audio_end"]
n_audio_tokens = full_inputs["n_audio_tokens"]
audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE
# Step 3: Full prefill (we always re-prefill since audio tokens change)
# Future optimization: partial prefill when only tail audio changes
out = thinker(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
use_cache=True,
)
kv_cache = out.past_key_values
prompt_len = input_ids.shape[1]
# Step 4: Greedy decode with alignment head stopping
border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
last_attend_frame = state.last_attend_frame
# Install hooks for alignment head attention extraction
decoder_layers = thinker.model.layers
num_kv_heads = asr.num_kv_heads
num_heads = asr.num_heads
gqa_ratio = num_heads // num_kv_heads
from qwen_asr.core.transformers_backend.modeling_qwen3_asr import apply_rotary_pos_emb
per_step_frames: List[List[int]] = []
current_step_frames: List[int] = []
hooks = []
def _make_attn_hook(layer_idx):
head_indices = asr.heads_by_layer[layer_idx]
def hook_fn(module, args, kwargs, output):
hidden_states = kwargs.get('hidden_states')
if hidden_states is None:
hidden_states = args[0] if args else None
if hidden_states is None or hidden_states.shape[1] != 1:
return
position_embeddings = kwargs.get('position_embeddings')
if position_embeddings is None and len(args) > 1:
position_embeddings = args[1]
past_kv = kwargs.get('past_key_values')
if position_embeddings is None or past_kv is None:
return
hidden_shape = (*hidden_states.shape[:-1], -1, module.head_dim)
q = module.q_norm(module.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
cos, sin = position_embeddings
q, _ = apply_rotary_pos_emb(q, q, cos, sin)
cache_layer = past_kv.layers[module.layer_idx]
k = cache_layer.keys
if k is None or audio_end > k.shape[2]:
return
for h_idx in head_indices:
if h_idx >= q.shape[1]:
continue
kv_h_idx = h_idx // gqa_ratio
q_h = q[0, h_idx, 0]
k_audio = k[0, kv_h_idx, audio_start:audio_end]
scores = torch.matmul(k_audio, q_h)
frame = scores.argmax().item()
current_step_frames.append(frame)
return hook_fn
for layer_idx in asr.heads_by_layer:
if layer_idx < len(decoder_layers):
h = decoder_layers[layer_idx].self_attn.register_forward_hook(
_make_attn_hook(layer_idx), with_kwargs=True,
)
hooks.append(h)
try:
# Greedy decoding with alignment-based stopping
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
generated_ids = []
border_stop_step = None
tokens_per_sec = 6
if is_last:
max_tokens = min(int(audio_duration * tokens_per_sec) + 10, 120)
else:
new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
max_tokens = min(int(max(new_audio_secs, 1.0) * tokens_per_sec) + 5, 40)
for step in range(max_tokens):
tid = next_token.item()
if tid in asr.eos_ids:
break
generated_ids.append(tid)
# Collect alignment frames for this step
if current_step_frames:
per_step_frames.append(current_step_frames)
current_step_frames = []
# Check stopping criteria (after 3 tokens)
if not is_last and len(per_step_frames) >= 3:
latest = per_step_frames[-1]
if latest:
frames_sorted = sorted(latest)
attended = frames_sorted[len(frames_sorted) // 2]
if last_attend_frame - attended > rewind_threshold:
border_stop_step = max(0, len(per_step_frames) - 2)
break
last_attend_frame = attended
if (n_audio_tokens - attended) <= border_threshold:
border_stop_step = len(per_step_frames) - 1
break
# Next token
out = thinker(
input_ids=next_token,
past_key_values=kv_cache,
use_cache=True,
)
kv_cache = out.past_key_values
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
# Flush remaining frames
if current_step_frames:
per_step_frames.append(current_step_frames)
finally:
for h in hooks:
h.remove()
state.last_attend_frame = last_attend_frame
if not generated_ids:
return []
# Strip metadata prefix (<asr_text> token)
all_generated = torch.tensor(generated_ids, device=asr.device)
num_gen = len(generated_ids)
asr_text_id = asr.asr_text_token_id
metadata_offset = 0
for i in range(min(num_gen, 10)):
if generated_ids[i] == asr_text_id:
if state.detected_language is None and i > 0:
from whisperlivekit.qwen3_asr import QWEN3_TO_WHISPER_LANGUAGE
prefix_text = asr.processor.tokenizer.decode(
generated_ids[:i], skip_special_tokens=True,
).strip()
parts = prefix_text.split()
if len(parts) >= 2:
lang_name = parts[-1]
if lang_name.lower() != "none":
state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
lang_name, lang_name.lower(),
)
metadata_offset = i + 1
break
if metadata_offset > 0:
generated_ids = generated_ids[metadata_offset:]
num_gen -= metadata_offset
per_step_frames = per_step_frames[metadata_offset:]
if num_gen <= 0:
return []
# Determine emit count
if border_stop_step is not None:
emit_up_to = min(border_stop_step, num_gen)
else:
emit_up_to = num_gen
emitted_ids = generated_ids[:emit_up_to]
if not emitted_ids:
return []
# Build timestamped words
words = self._build_timestamped_words(
emitted_ids, per_step_frames, emit_up_to,
n_audio_tokens, audio_duration,
)
state.committed_word_count += len(words)
# Include metadata in committed tokens for context
all_emitted = generated_ids[:emit_up_to]
if metadata_offset > 0:
all_emitted = generated_ids[:emit_up_to] # already stripped
state.committed_token_ids.extend(all_emitted)
return words
def _build_timestamped_words(
self,
generated_ids: list,
step_frames: List[List[int]],
emit_up_to: int,
n_audio_tokens: int,
audio_duration: float,
) -> List[ASRToken]:
asr = self.asr
state = self.state
per_token_frame = []
for step in range(emit_up_to):
if step < len(step_frames) and step_frames[step]:
frames = sorted(step_frames[step])
per_token_frame.append(frames[len(frames) // 2])
else:
per_token_frame.append(None)
tokenizer = asr.processor.tokenizer
full_text = tokenizer.decode(generated_ids[:emit_up_to], skip_special_tokens=True)
text_words = full_text.split()
all_frames = [f for f in per_token_frame if f is not None]
words = []
for wi, word in enumerate(text_words):
if all_frames:
frac = wi / max(len(text_words), 1)
frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
frame = all_frames[frame_idx]
else:
frame = None
words.append((word, frame))
tokens = []
for i, (text, frame) in enumerate(words):
text = text.strip()
if not text:
continue
if frame is not None and n_audio_tokens > 0:
timestamp = (
frame / n_audio_tokens * audio_duration
+ state.cumulative_time_offset
)
else:
timestamp = (
(i / max(len(words), 1)) * audio_duration
+ state.cumulative_time_offset
)
is_very_first_word = (i == 0 and state.committed_word_count == 0)
display_text = text if is_very_first_word else " " + text
token = ASRToken(
start=round(timestamp, 2),
end=round(timestamp + 0.1, 2),
text=display_text,
speaker=state.speaker,
detected_language=state.detected_language,
).with_offset(state.global_time_offset)
tokens.append(token)
return tokens
def warmup(self, audio: np.ndarray, init_prompt: str = ""):
try:
self.state.audio_buffer = audio[:SAMPLE_RATE]
self.process_iter(is_last=True)
self.state = Qwen3SimulKVState()
except Exception as e:
logger.warning("Warmup failed: %s", e)
self.state = Qwen3SimulKVState()
def finish(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end