Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a540a5fd10 | ||
|
|
7b08ea74ab | ||
|
|
b69eaf82be | ||
|
|
ed503be140 |
BIN
benchmark_bars_h100.png
Normal file
|
After Width: | Height: | Size: 193 KiB |
BIN
benchmark_latency_h100.png
Normal file
|
After Width: | Height: | Size: 84 KiB |
BIN
benchmark_robustness_h100.png
Normal file
|
After Width: | Height: | Size: 101 KiB |
|
Before Width: | Height: | Size: 95 KiB After Width: | Height: | Size: 100 KiB |
BIN
benchmark_scatter_en_h100.png
Normal file
|
After Width: | Height: | Size: 147 KiB |
|
Before Width: | Height: | Size: 95 KiB After Width: | Height: | Size: 100 KiB |
BIN
benchmark_scatter_h100.png
Normal file
|
After Width: | Height: | Size: 204 KiB |
3346
scripts/alignment_heads_qwen3_asr_0.6B.json
Normal file
3292
scripts/alignment_heads_qwen3_asr_1.7B_v2.json
Normal file
@@ -266,6 +266,8 @@ def generate_scatter(results, system_info, output_path, n_samples, lang="en",
|
||||
"mlx SS small": (-55, -5),
|
||||
"voxtral mlx": (10, -14),
|
||||
"qwen3 0.6B": (10, 8),
|
||||
"qwen3-mlx 0.6B": (10, -14),
|
||||
"qwen3-mlx 1.7B": (10, 8),
|
||||
"fw LA large-v3": (8, -5),
|
||||
"fw SS large-v3": (8, 5),
|
||||
}
|
||||
|
||||
@@ -51,6 +51,12 @@ try:
|
||||
except (ImportError, Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
import mlx_qwen3_asr # noqa: F401
|
||||
AVAILABLE_BACKENDS.append("qwen3-mlx")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
BACKEND_CONFIG = {
|
||||
"whisper": {"model_size": "tiny", "lan": "en"},
|
||||
"voxtral-mlx": {"backend": "voxtral-mlx", "lan": "en"},
|
||||
@@ -61,6 +67,7 @@ BACKEND_CONFIG = {
|
||||
"lan": "en",
|
||||
"custom_alignment_heads": "scripts/alignment_heads_qwen3_asr_1.7B.json",
|
||||
},
|
||||
"qwen3-mlx": {"backend": "qwen3-mlx", "lan": "en"},
|
||||
}
|
||||
|
||||
# Voxtral backends flush all words at once with proportionally-distributed
|
||||
@@ -70,7 +77,7 @@ BACKEND_CONFIG = {
|
||||
VOXTRAL_BACKENDS = {"voxtral-mlx", "voxtral-hf"}
|
||||
|
||||
# Backends that use batch-flush and may have non-monotonic timestamps
|
||||
BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul"}
|
||||
BATCH_FLUSH_BACKENDS = {"voxtral-mlx", "voxtral-hf", "qwen3", "qwen3-simul", "qwen3-mlx"}
|
||||
|
||||
|
||||
def backend_kwargs(backend: str) -> dict:
|
||||
|
||||
116
whisperlivekit/cascade_bridge.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Bridge between WhisperLiveKit STT and IWSLT26 MT pipeline.
|
||||
|
||||
Converts streaming ASRToken output from SimulStreaming into the JSONL
|
||||
format expected by the AlignAtt MT agent (iwslt26-sst).
|
||||
|
||||
Output format (one JSON per line):
|
||||
{"text": "word or phrase", "emission_time": 1.234, "is_final": false, "speech_time": 1.0}
|
||||
|
||||
Where:
|
||||
- text: the emitted word/phrase
|
||||
- emission_time: wall-clock time when the word was emitted (for compute-aware eval)
|
||||
- speech_time: timestamp in the audio (for compute-unaware eval)
|
||||
- is_final: whether this is the last word of a segment/silence boundary
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List, TextIO
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
|
||||
|
||||
class CascadeBridge:
    """Converts ASRToken stream to JSONL for the MT agent."""

    def __init__(self, output_file: TextIO = None):
        # Optional live sink: when given, each entry is also written and
        # flushed to this stream as soon as it is emitted.
        self.output_file = output_file
        self.start_time = time.time()
        self.entries: List[dict] = []

    def emit_tokens(self, tokens: List[ASRToken], is_final: bool = False):
        """Emit a batch of tokens from the STT."""
        # One wall-clock stamp per batch: every word of the batch shares
        # the same emission time.
        elapsed = round(time.time() - self.start_time, 3)
        last = len(tokens) - 1
        for idx, tok in enumerate(tokens):
            record = {
                "text": tok.text.strip(),
                "emission_time": elapsed,
                "speech_time": round(tok.start, 3),
                # Only the final word of a final batch is flagged.
                "is_final": is_final and idx == last,
            }
            self.entries.append(record)
            if self.output_file:
                self.output_file.write(json.dumps(record) + "\n")
                self.output_file.flush()

    def get_entries(self) -> List[dict]:
        return self.entries

    def get_text(self) -> str:
        """Get the full transcribed text."""
        parts = [e["text"] for e in self.entries if e["text"]]
        return " ".join(parts)

    def save(self, path: str):
        """Save all entries to a JSONL file."""
        with open(path, "w") as sink:
            sink.writelines(json.dumps(e) + "\n" for e in self.entries)
||||
|
||||
|
||||
def run_stt_to_jsonl(
    audio_path: str,
    output_path: str,
    model_id: str = "Qwen/Qwen3-ASR-0.6B",
    alignment_heads_path: str = None,
    border_fraction: float = 0.20,
    language: str = "en",
    chunk_sec: float = 1.0,
):
    """Run STT on an audio file and save JSONL output for the MT agent.

    This is the main entry point for the cascade: audio file → JSONL.
    """
    import wave
    import numpy as np
    from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR, Qwen3SimulKVOnlineProcessor

    sample_rate = 16000

    # Decode the WAV into float32 PCM in [-1, 1].
    # NOTE(review): assumes 16-bit mono audio at 16 kHz — confirm upstream.
    with wave.open(audio_path, 'r') as wav_reader:
        raw = wav_reader.readframes(wav_reader.getnframes())
    audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0

    # Build the streaming STT stack.
    asr = Qwen3SimulKVASR(
        model_dir=model_id,
        lan=language,
        alignment_heads_path=alignment_heads_path,
        border_fraction=border_fraction,
    )
    processor = Qwen3SimulKVOnlineProcessor(asr)
    bridge = CascadeBridge()

    # Stream the audio in fixed-size chunks, emitting words as they commit.
    chunk_samples = int(chunk_sec * sample_rate)
    stream_time = 0.0
    for start in range(0, len(audio), chunk_samples):
        chunk = audio[start:start + chunk_samples]
        stream_time += len(chunk) / sample_rate
        processor.insert_audio_chunk(chunk, stream_time)
        words, _ = processor.process_iter(is_last=False)
        if words:
            bridge.emit_tokens(words, is_final=False)

    # Flush whatever the streaming policy still holds back.
    tail_words, _ = processor.finish()
    if tail_words:
        bridge.emit_tokens(tail_words, is_final=True)

    bridge.save(output_path)
    return bridge
|
||||
@@ -109,6 +109,16 @@ BACKENDS = [
|
||||
"streaming": "chunk",
|
||||
"devices": ["cuda", "mps", "cpu"],
|
||||
},
|
||||
{
|
||||
"id": "qwen3-mlx",
|
||||
"name": "Qwen3 MLX",
|
||||
"module": "mlx_qwen3_asr",
|
||||
"install": "pip install mlx-qwen3-asr",
|
||||
"description": "Qwen3-ASR on Apple Silicon (MLX, native streaming)",
|
||||
"platform": "darwin-arm64",
|
||||
"streaming": "native",
|
||||
"devices": ["mlx"],
|
||||
},
|
||||
{
|
||||
"id": "openai-api",
|
||||
"name": "OpenAI API",
|
||||
@@ -193,6 +203,9 @@ MODEL_CATALOG = [
|
||||
# Qwen3 ASR
|
||||
{"name": "qwen3:1.7b", "family": "qwen3", "params": "1.7B", "disk": "3.6 GB", "languages": 12, "quality": "good", "speed": "fast"},
|
||||
{"name": "qwen3:0.6b", "family": "qwen3", "params": "0.6B", "disk": "1.4 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
|
||||
# Qwen3 MLX (native streaming on Apple Silicon)
|
||||
{"name": "qwen3-mlx:1.7b", "family": "qwen3-mlx", "params": "1.7B", "disk": "1.8 GB", "languages": 12, "quality": "good", "speed": "fast"},
|
||||
{"name": "qwen3-mlx:0.6b", "family": "qwen3-mlx", "params": "0.6B", "disk": "0.7 GB", "languages": 12, "quality": "fair", "speed": "fastest"},
|
||||
]
|
||||
|
||||
|
||||
@@ -310,6 +323,9 @@ def _model_is_downloaded(model_entry: dict, downloaded: dict) -> bool:
|
||||
elif family == "qwen3":
|
||||
size = name.split(":")[1] if ":" in name else "1.7b"
|
||||
return QWEN3_REPOS.get(size, "") in downloaded
|
||||
elif family == "qwen3-mlx":
|
||||
size = name.split(":")[1] if ":" in name else "1.7b"
|
||||
return QWEN3_REPOS.get(size, "") in downloaded
|
||||
return False
|
||||
|
||||
|
||||
@@ -324,6 +340,8 @@ def _best_backend_for_model(model_entry: dict) -> str:
|
||||
return "voxtral"
|
||||
elif family == "qwen3":
|
||||
return "qwen3"
|
||||
elif family == "qwen3-mlx":
|
||||
return "qwen3-mlx"
|
||||
elif family == "whisper":
|
||||
if is_apple and _module_available("mlx_whisper"):
|
||||
return "mlx-whisper"
|
||||
@@ -383,6 +401,8 @@ def cmd_models():
|
||||
# Skip platform-incompatible models
|
||||
if name == "voxtral-mlx" and not is_apple_silicon:
|
||||
continue
|
||||
if m["family"] == "qwen3-mlx" and not is_apple_silicon:
|
||||
continue
|
||||
|
||||
is_dl = _model_is_downloaded(m, downloaded)
|
||||
|
||||
@@ -447,6 +467,18 @@ def _resolve_pull_target(spec: str):
|
||||
targets.append(("voxtral-mlx", VOXTRAL_MLX_REPO, "Voxtral Mini (MLX)"))
|
||||
return targets
|
||||
|
||||
# Handle qwen3-mlx (must check before generic qwen3)
|
||||
if backend_part == "qwen3-mlx" or size_part.startswith("qwen3-mlx"):
|
||||
qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b"
|
||||
if qwen_size.startswith("qwen3"):
|
||||
qwen_size = "1.7b" # default
|
||||
repo = QWEN3_REPOS.get(qwen_size)
|
||||
if not repo:
|
||||
print(f" Unknown Qwen3 size: {qwen_size}. Available: {', '.join(QWEN3_REPOS.keys())}")
|
||||
return []
|
||||
targets.append(("qwen3-mlx", repo, f"Qwen3-ASR MLX {qwen_size}"))
|
||||
return targets
|
||||
|
||||
# Handle qwen3
|
||||
if backend_part == "qwen3" or size_part.startswith("qwen3"):
|
||||
qwen_size = size_part.split(":")[-1] if ":" in spec else "1.7b"
|
||||
@@ -503,7 +535,7 @@ def _resolve_pull_target(spec: str):
|
||||
else:
|
||||
print(f" Unknown model: {spec}")
|
||||
print(f" Available sizes: {', '.join(WHISPER_SIZES)}")
|
||||
print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b")
|
||||
print(" Other models: voxtral, voxtral-mlx, qwen3:1.7b, qwen3:0.6b, qwen3-mlx:1.7b, qwen3-mlx:0.6b")
|
||||
return []
|
||||
|
||||
return targets
|
||||
@@ -986,6 +1018,9 @@ def _resolve_run_spec(spec: str):
|
||||
if spec == "voxtral-mlx":
|
||||
return "voxtral-mlx", None
|
||||
|
||||
if spec == "qwen3-mlx":
|
||||
return "qwen3-mlx", None
|
||||
|
||||
if spec in WHISPER_SIZES:
|
||||
return None, spec
|
||||
|
||||
@@ -1231,6 +1266,12 @@ def _probe_backend_state(processor) -> dict:
|
||||
elif hasattr(transcription, "_mlx_processor"):
|
||||
info["backend_type"] = "voxtral-mlx"
|
||||
|
||||
# Qwen3 MLX specifics
|
||||
elif hasattr(transcription, "_session") and hasattr(transcription, "_state"):
|
||||
info["backend_type"] = "qwen3-mlx"
|
||||
info["samples_fed"] = getattr(transcription, "_samples_fed", 0)
|
||||
info["committed_words"] = getattr(transcription, "_n_committed_words", 0)
|
||||
|
||||
# SimulStreaming specifics
|
||||
elif hasattr(transcription, "prev_output"):
|
||||
info["backend_type"] = "simulstreaming"
|
||||
|
||||
@@ -121,6 +121,20 @@ class TranscriptionEngine:
|
||||
self.tokenizer = None
|
||||
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
||||
logger.info("Using Voxtral HF Transformers streaming backend")
|
||||
elif config.backend == "qwen3-mlx":
|
||||
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXASR
|
||||
self.tokenizer = None
|
||||
self.asr = Qwen3MLXASR(**transcription_common_params)
|
||||
logger.info("Using Qwen3 MLX native backend")
|
||||
elif config.backend == "qwen3-simul-kv":
|
||||
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR
|
||||
self.tokenizer = None
|
||||
self.asr = Qwen3SimulKVASR(
|
||||
**transcription_common_params,
|
||||
alignment_heads_path=config.custom_alignment_heads,
|
||||
border_fraction=getattr(config, 'border_fraction', 0.25),
|
||||
)
|
||||
logger.info("Using Qwen3-ASR backend with SimulStreaming+KV policy")
|
||||
elif config.backend == "qwen3-simul":
|
||||
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR
|
||||
self.tokenizer = None
|
||||
@@ -230,6 +244,12 @@ def online_factory(args, asr, language=None):
|
||||
if backend == "vllm-realtime":
|
||||
from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor
|
||||
return VLLMRealtimeOnlineProcessor(asr)
|
||||
if backend == "qwen3-simul-kv":
|
||||
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVOnlineProcessor
|
||||
return Qwen3SimulKVOnlineProcessor(asr)
|
||||
if backend == "qwen3-mlx":
|
||||
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor
|
||||
return Qwen3MLXOnlineProcessor(asr)
|
||||
if backend == "qwen3-simul":
|
||||
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingOnlineProcessor
|
||||
return Qwen3SimulStreamingOnlineProcessor(asr)
|
||||
|
||||
@@ -147,8 +147,8 @@ def parse_args():
|
||||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-simul", "vllm-realtime"],
|
||||
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-simul", "vllm-realtime"],
|
||||
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-mlx' for Qwen3-ASR on Apple Silicon (MLX). Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
|
||||
392
whisperlivekit/qwen3_mlx_asr.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""
|
||||
MLX-accelerated Qwen3-ASR backend for WhisperLiveKit.
|
||||
|
||||
Provides ``Qwen3MLXASR`` (model holder) and ``Qwen3MLXOnlineProcessor``
|
||||
(batch-based processor) that plug into WhisperLiveKit's audio processing
|
||||
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
|
||||
|
||||
Uses the ``mlx-qwen3-asr`` package for fast Qwen3 inference on Apple Silicon.
|
||||
The batch ``session.transcribe()`` API is called on the full accumulated audio
|
||||
buffer, and LocalAgreement-style diffing (HypothesisBuffer) commits stable
|
||||
words across consecutive inferences.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Whisper language codes -> Qwen3 canonical language names
|
||||
# (duplicated from qwen3_asr.py to avoid importing torch at module level)
|
||||
# Whisper language codes -> Qwen3 canonical language names
# (duplicated from qwen3_asr.py to avoid importing torch at module level)
WHISPER_TO_QWEN3_LANGUAGE = {
    "zh": "Chinese",
    "en": "English",
    "yue": "Cantonese",
    "ar": "Arabic",
    "de": "German",
    "fr": "French",
    "es": "Spanish",
    "pt": "Portuguese",
    "id": "Indonesian",
    "it": "Italian",
    "ko": "Korean",
    "ru": "Russian",
    "th": "Thai",
    "vi": "Vietnamese",
    "ja": "Japanese",
    "tr": "Turkish",
    "hi": "Hindi",
    "ms": "Malay",
    "nl": "Dutch",
    "sv": "Swedish",
    "da": "Danish",
    "fi": "Finnish",
    "pl": "Polish",
    "cs": "Czech",
    "fa": "Persian",
    "el": "Greek",
    "hu": "Hungarian",
    "mk": "Macedonian",
    "ro": "Romanian",
}

# Model size aliases -> HuggingFace model IDs.
# Whisper-style size names map onto the two available Qwen3-ASR checkpoints.
QWEN3_MLX_MODEL_MAPPING = {
    "base": "Qwen/Qwen3-ASR-0.6B",
    "tiny": "Qwen/Qwen3-ASR-0.6B",
    "small": "Qwen/Qwen3-ASR-0.6B",
    "large": "Qwen/Qwen3-ASR-1.7B",
    "medium": "Qwen/Qwen3-ASR-1.7B",
    "large-v3": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
    "qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
    "1.7b": "Qwen/Qwen3-ASR-1.7B",
    "0.6b": "Qwen/Qwen3-ASR-0.6B",
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model holder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Qwen3MLXASR:
    """Lightweight model holder -- loads the mlx-qwen3-asr model once and
    keeps it alive for the lifetime of the server."""

    # Separator used when joining committed tokens; tokens already carry
    # their own leading spaces, so no extra separator is needed.
    sep = ""
    SAMPLING_RATE = 16_000

    def __init__(self, logfile=sys.stderr, **kwargs):
        # Imported lazily so this module can be imported on platforms
        # without mlx installed.
        import mlx.core as mx
        import mlx_qwen3_asr

        self.logfile = logfile
        self.transcribe_kargs = {}

        lang_code = kwargs.get("lan", "auto")
        self.original_language = None if lang_code == "auto" else lang_code

        # Resolve the model: an explicit dir/path wins, then a size alias,
        # then the 0.6B default.
        resolved = kwargs.get("model_dir") or kwargs.get("model_path")
        if not resolved:
            size = kwargs.get("model_size", "")
            if size and ("/" in size or size.startswith(".")):
                # Already looks like a HF repo id or a filesystem path.
                resolved = size
            else:
                resolved = QWEN3_MLX_MODEL_MAPPING.get(
                    (size or "base").lower(), "Qwen/Qwen3-ASR-0.6B"
                )

        load_start = time.time()
        logger.info("Loading Qwen3 MLX model '%s' ...", resolved)
        self.session = mlx_qwen3_asr.Session(resolved, dtype=mx.float16)
        logger.info("Qwen3 MLX model loaded in %.2fs", time.time() - load_start)

        self.backend_choice = "qwen3-mlx"
        self.tokenizer = None

    def transcribe(self, audio):
        # Intentionally a no-op: all inference happens in the online
        # processor, which calls ``session.transcribe`` directly.
        pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Online processor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Qwen3MLXOnlineProcessor:
    """Batch-based processor that accumulates audio and periodically calls
    ``session.transcribe()`` on the full buffer.

    Uses LocalAgreement-style diffing (implemented inline in
    ``_local_agreement``, mirroring HypothesisBuffer) to commit stable
    words across consecutive inferences, like the PyTorch Qwen3
    backend with ``OnlineASRProcessor``.

    Lifecycle (called by ``AudioProcessor.transcription_processor``):

        insert_audio_chunk(pcm, time) -> process_iter() -> get_buffer()
        ... repeat ...
        start_silence() / end_silence()
        finish()
    """

    SAMPLING_RATE = 16_000

    def __init__(self, asr: Qwen3MLXASR, logfile=sys.stderr):
        self.asr = asr
        self.logfile = logfile
        # Absolute stream time (seconds) of the most recently ingested audio.
        self.end = 0.0

        self._session = asr.session
        lan = asr.original_language
        # None means let the model auto-detect; otherwise map the Whisper
        # code to Qwen3's canonical language name (English as fallback).
        self._language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, "English") if lan else None

        # Audio accumulation
        self.audio_buffer = np.array([], dtype=np.float32)
        self._buffer_time_offset: float = 0.0  # absolute time of audio_buffer[0]

        # Throttle: minimum new audio (in samples) before re-running inference
        self._min_new_samples: int = int(1.0 * self.SAMPLING_RATE)  # 1 second
        self._samples_since_last_inference: int = 0

        # Buffer trimming — keep buffer short for fast re-transcription.
        # The model produces ~0.2x RTF, so 15s buffer = ~3s per call.
        self._max_buffer_sec: float = 15.0
        self._trim_sec: float = 10.0  # keep this many seconds after trimming

        # LocalAgreement state (HypothesisBuffer-style, kept inline here)
        self._committed: List[ASRToken] = []
        self._prev_tokens: List[ASRToken] = []  # previous hypothesis (buffer role)
        self._last_committed_time: float = 0.0

        # Global time tracking
        # NOTE(review): accumulated in end_silence() but never applied to
        # token timestamps anywhere in this class — confirm intent.
        self._global_time_offset: float = 0.0  # extra offset from silences

    # -- audio ingestion --

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        # Append raw PCM and remember how far the stream has advanced.
        self.end = audio_stream_end_time
        self.audio_buffer = np.append(self.audio_buffer, audio)
        self._samples_since_last_inference += len(audio)

    # -- batch transcription --

    def _transcribe_buffer(self) -> List[ASRToken]:
        """Run batch transcription on the full audio buffer and return tokens."""
        # 400 samples = 25 ms at 16 kHz.
        if len(self.audio_buffer) < 400:  # too short for meaningful transcription
            return []

        t0 = time.time()
        try:
            result = self._session.transcribe(
                self.audio_buffer,
                language=self._language,
                return_timestamps=True,
            )
        except Exception as e:
            # Best-effort: a failed inference yields no tokens rather than
            # killing the streaming loop.
            logger.warning("[qwen3-mlx] transcribe error: %s", e, exc_info=True)
            return []
        dur = time.time() - t0
        audio_dur = len(self.audio_buffer) / self.SAMPLING_RATE
        logger.debug(
            "[qwen3-mlx] transcribed %.1fs audio in %.2fs (%.2fx RTF)",
            audio_dur, dur, dur / max(audio_dur, 0.01),
        )

        text = (result.text or "").strip()
        if not text:
            return []

        # Build tokens from segments (word-level timestamps)
        tokens: List[ASRToken] = []
        if result.segments:
            for i, seg in enumerate(result.segments):
                word = seg["text"]
                # Segment times are buffer-relative; shift to absolute time.
                start = self._buffer_time_offset + seg["start"]
                end = self._buffer_time_offset + seg["end"]
                # Prefix a space on all but the first word so plain
                # concatenation reconstructs the text.
                label = word if i == 0 else " " + word
                tokens.append(ASRToken(start=start, end=end, text=label))
        else:
            # Fallback: estimate timestamps by spreading words evenly
            # over the buffer duration.
            words = text.split()
            step = audio_dur / max(len(words), 1)
            for i, w in enumerate(words):
                t_start = self._buffer_time_offset + i * step
                t_end = self._buffer_time_offset + (i + 1) * step
                label = w if i == 0 else " " + w
                tokens.append(ASRToken(start=t_start, end=t_end, text=label))

        return tokens

    def _local_agreement(self, new_tokens: List[ASRToken]) -> List[ASRToken]:
        """LocalAgreement diffing: commit the longest common prefix between
        the previous hypothesis (``self._prev_tokens``) and the new tokens.

        Before comparing, strips tokens that correspond to already-committed
        audio (i.e., tokens whose start time is before ``_last_committed_time``).
        Also deduplicates boundary tokens (ngram matching) to avoid re-committing
        the tail of the previous committed output.

        Returns the newly committed tokens.
        """
        # Step 1: Only keep tokens that are roughly "new" (after last
        # committed time; 0.1s tolerance for timestamp jitter).
        fresh_tokens = [
            t for t in new_tokens
            if t.start > self._last_committed_time - 0.1
        ]

        # Step 2: Remove duplicates at the boundary with committed tokens
        # (like HypothesisBuffer.insert's ngram dedup). Longest match wins
        # up to 5 words.
        if fresh_tokens and self._committed:
            max_ngram = min(len(self._committed), len(fresh_tokens), 5)
            for n in range(1, max_ngram + 1):
                committed_ngram = " ".join(
                    t.text.strip() for t in self._committed[-n:]
                )
                fresh_ngram = " ".join(
                    t.text.strip() for t in fresh_tokens[:n]
                )
                if committed_ngram == fresh_ngram:
                    fresh_tokens = fresh_tokens[n:]
                    break

        # Step 3: LocalAgreement -- longest common prefix between prev and fresh
        committed: List[ASRToken] = []
        prev = self._prev_tokens
        i = 0
        j = 0

        while i < len(fresh_tokens) and j < len(prev):
            if fresh_tokens[i].text.strip() == prev[j].text.strip():
                # Agreement: commit this token (use the new token's timestamps)
                committed.append(fresh_tokens[i])
                i += 1
                j += 1
            else:
                break

        # The remaining fresh tokens become the new "previous hypothesis"
        self._prev_tokens = fresh_tokens[i:] if i < len(fresh_tokens) else []
        return committed

    def _trim_buffer_if_needed(self):
        """Trim the audio buffer if it exceeds max_buffer_sec.

        Keeps the last ``_trim_sec`` seconds of audio. Also adjusts
        committed token tracking and buffer_time_offset.
        """
        buffer_dur = len(self.audio_buffer) / self.SAMPLING_RATE
        if buffer_dur <= self._max_buffer_sec:
            return

        keep_sec = self._trim_sec
        keep_samples = int(keep_sec * self.SAMPLING_RATE)
        cut_samples = len(self.audio_buffer) - keep_samples
        if cut_samples <= 0:
            return

        cut_sec = cut_samples / self.SAMPLING_RATE
        self.audio_buffer = self.audio_buffer[cut_samples:]
        self._buffer_time_offset += cut_sec

        # Remove committed tokens that are before the new buffer start;
        # they can no longer collide with freshly transcribed tokens.
        self._committed = [
            t for t in self._committed if t.end > self._buffer_time_offset
        ]

        logger.debug(
            "[qwen3-mlx] trimmed buffer: cut %.1fs, new offset %.1f, buffer %.1fs",
            cut_sec, self._buffer_time_offset, len(self.audio_buffer) / self.SAMPLING_RATE,
        )

    # -- interface methods --

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Process the current audio buffer.

        Throttles inference to at least 1s of new audio between calls.
        Returns (newly_committed_tokens, audio_processed_upto_time).
        """
        try:
            # Throttle: skip if not enough new audio since last inference
            if (not is_last
                    and self._samples_since_last_inference < self._min_new_samples):
                return [], self.end

            self._samples_since_last_inference = 0

            # Trim buffer if too long
            self._trim_buffer_if_needed()

            # Run batch transcription
            new_tokens = self._transcribe_buffer()

            # LocalAgreement diffing
            committed = self._local_agreement(new_tokens)

            if committed:
                self._committed.extend(committed)
                self._last_committed_time = committed[-1].end

            return committed, self.end
        except Exception as e:
            # Never propagate out of the streaming loop; log and yield nothing.
            logger.warning("[qwen3-mlx] process_iter error: %s", e, exc_info=True)
            return [], self.end

    def get_buffer(self) -> Transcript:
        """Return the unconfirmed text (the tail of the last hypothesis
        that was not committed by LocalAgreement)."""
        if not self._prev_tokens:
            return Transcript(start=None, end=None, text="")

        # Tokens carry their own leading spaces, so plain concatenation works.
        text = "".join(t.text for t in self._prev_tokens)
        start = self._prev_tokens[0].start
        end = self._prev_tokens[-1].end
        return Transcript(start=start, end=end, text=text)

    def _flush_all(self) -> List[ASRToken]:
        """Force a final transcription and commit all remaining words."""
        # Run one last transcription on the full buffer
        self._samples_since_last_inference = self._min_new_samples  # bypass throttle
        new_tokens = self._transcribe_buffer()

        # Commit everything: first the agreed prefix, then the remainder
        committed = self._local_agreement(new_tokens)

        # Also commit any remaining buffer tokens
        remaining = self._prev_tokens
        self._prev_tokens = []

        all_new = committed + remaining
        if all_new:
            self._committed.extend(all_new)
            self._last_committed_time = all_new[-1].end

        return all_new

    def _reset_for_new_utterance(self):
        """Reset buffers for a new utterance, preserving time continuity.

        NOTE(review): not called anywhere inside this class — presumably
        reserved for external callers; verify before removing.
        """
        new_offset = self._buffer_time_offset + len(self.audio_buffer) / self.SAMPLING_RATE
        saved_end = self.end

        self.audio_buffer = np.array([], dtype=np.float32)
        self._buffer_time_offset = new_offset
        self._samples_since_last_inference = 0
        self._committed = []
        self._prev_tokens = []

        self.end = saved_end

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush pending words when silence starts.

        Unlike other backends, does NOT reset the audio buffer — the model
        produces better results re-transcribing the full accumulated audio.
        Buffer trimming (``_max_buffer_sec``) handles memory naturally.
        """
        words = self._flush_all()
        logger.info("[qwen3-mlx] start_silence: flushed %d words", len(words))
        return words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        # Shift stream time forward past the skipped silence.
        # NOTE(review): the ``offset`` parameter is unused here — confirm
        # whether that matches the other backends' contract.
        self._global_time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        # On speaker change, just flush pending words (return value discarded).
        self.start_silence()

    def warmup(self, audio, init_prompt=""):
        # No warmup needed for the MLX session; kept for interface parity.
        pass

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Final flush at end of stream; returns (words, stream_end_time)."""
        words = self._flush_all()
        logger.info("[qwen3-mlx] finish: flushed %d words", len(words))
        return words, self.end
|
||||
790
whisperlivekit/qwen3_simul_kv.py
Normal file
@@ -0,0 +1,790 @@
|
||||
"""
|
||||
Qwen3-ASR SimulStreaming with KV cache reuse.
|
||||
|
||||
This is an optimized version of qwen3_simul.py that reuses the KV cache
|
||||
across inference calls, avoiding redundant prefill of prompt + old audio.
|
||||
|
||||
Architecture:
|
||||
1. First call: full prefill (prompt + audio tokens), greedy decode with
|
||||
alignment-head stopping, save KV cache + generated tokens
|
||||
2. Subsequent calls: invalidate KV for old audio suffix, prefill only
|
||||
new audio tokens, continue decoding from saved state
|
||||
3. Audio encoder caching: reuse embeddings for stable attention windows
|
||||
|
||||
This gives ~3-5x speedup over the original generate()-based approach.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import DynamicCache
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
@dataclass
class Qwen3SimulKVConfig:
    """Configuration for Qwen3 SimulStreaming with KV cache.

    NOTE(review): field semantics below are inferred from names and the
    module docstring — confirm against qwen3_simul.py.
    """

    # HuggingFace model id to load.
    model_id: str = "Qwen/Qwen3-ASR-1.7B"
    # JSON file with alignment-head definitions; None leaves it unset.
    alignment_heads_path: Optional[str] = None
    # Whisper-style language code, or "auto".
    language: str = "auto"
    border_fraction: float = 0.20
    rewind_fraction: float = 0.12
    # Audio window bounds, in seconds.
    audio_min_len: float = 0.5
    audio_max_len: float = 30.0
    max_context_tokens: int = 20
    # Optional prompt used at the start of decoding.
    init_prompt: Optional[str] = None
    max_alignment_heads: int = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class _AudioEmbedCache:
|
||||
"""Cache for audio encoder outputs."""
|
||||
encoded_samples: int = 0
|
||||
embeddings: Optional[torch.Tensor] = None
|
||||
encoded_mel_frames: int = 0
|
||||
stable_tokens: int = 0
|
||||
|
||||
def reset(self):
|
||||
self.encoded_samples = 0
|
||||
self.embeddings = None
|
||||
self.encoded_mel_frames = 0
|
||||
self.stable_tokens = 0
|
||||
|
||||
|
||||
@dataclass
class Qwen3SimulKVState:
    """Per-session mutable state with KV cache."""

    # Audio
    audio_buffer: np.ndarray = field(
        default_factory=lambda: np.array([], dtype=np.float32)
    )
    cumulative_time_offset: float = 0.0
    global_time_offset: float = 0.0
    speaker: int = -1

    # KV cache state
    kv_cache: Optional[DynamicCache] = None
    kv_seq_len: int = 0  # sequence length when KV was saved
    prompt_token_count: int = 0  # tokens before audio (system prompt etc)
    audio_token_count: int = 0  # audio tokens in the cached KV
    generated_token_ids: List[int] = field(default_factory=list)

    # Alignment tracking
    # NOTE(review): -15 is the sentinel/initial frame value used throughout
    # — confirm its derivation before changing.
    last_attend_frame: int = -15
    committed_text: str = ""
    committed_word_count: int = 0
    committed_token_ids: List[int] = field(default_factory=list)

    # Tracking
    first_timestamp: Optional[float] = None
    detected_language: Optional[str] = None
    last_infer_samples: int = 0

    # Audio embedding cache
    audio_cache: _AudioEmbedCache = field(default_factory=_AudioEmbedCache)

    def reset_kv(self):
        """Reset KV cache (e.g., when audio is trimmed from front)."""
        self.kv_cache = None
        self.kv_seq_len = self.prompt_token_count = self.audio_token_count = 0
        self.generated_token_ids = []
        # Old frame references are invalid once audio has been trimmed
        # from the front, so alignment tracking restarts as well.
        self.last_attend_frame = -15
|
||||
|
||||
|
||||
class Qwen3SimulKVASR:
    """
    Shared backend for Qwen3-ASR SimulStreaming with KV cache reuse.

    Loads the model/processor once; per-session decoding is driven by
    Qwen3SimulKVOnlineProcessor instances that share this object.
    """

    # Separator used by callers when joining emitted text pieces
    # (emitted tokens already carry their own leading spaces).
    sep = ""

    def __init__(
        self,
        model_size: Optional[str] = None,
        model_dir: Optional[str] = None,
        lan: str = "auto",
        alignment_heads_path: Optional[str] = None,
        border_fraction: float = 0.15,
        min_chunk_size: float = 0.1,
        warmup_file: Optional[str] = None,
        model_cache_dir: Optional[str] = None,
        model_path: Optional[str] = None,
        lora_path: Optional[str] = None,
        direct_english_translation: bool = False,
        **kwargs,
    ):
        """Load the Qwen3-ASR model and alignment heads.

        Args:
            model_size: Short model name resolved via QWEN3_MODEL_MAPPING.
            model_dir: Local model directory (highest priority as model id).
            lan: Source language code, or "auto" for detection.
            alignment_heads_path: JSON file with precomputed alignment heads.
            border_fraction: Fraction of audio tokens treated as the unstable
                border at the end of the buffer (decoding stops there).
            min_chunk_size: Accepted for interface compatibility; not used here.
            warmup_file: Optional audio file used to warm the model up.
            model_cache_dir, lora_path, direct_english_translation: Accepted
                for interface compatibility; not used by this backend.
            model_path: Fallback model id when model_dir/model_size are unset.
        """
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        self.warmup_file = warmup_file

        self.cfg = Qwen3SimulKVConfig(
            language=lan,
            alignment_heads_path=alignment_heads_path,
            border_fraction=border_fraction,
        )

        self._load_model(model_size, model_dir, model_cache_dir, model_path)
        self.alignment_heads = self._load_alignment_heads(alignment_heads_path)

        # Pre-compute heads by layer for efficient hook installation
        self.heads_by_layer = {}
        for layer_idx, head_idx in self.alignment_heads:
            self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)

        if warmup_file:
            from whisperlivekit.warmup import load_file
            audio = load_file(warmup_file)
            if audio is not None:
                self._warmup(audio)

    def _load_model(self, model_size, model_dir, model_cache_dir, model_path):
        """Resolve the model id, register the Qwen3-ASR classes with the
        transformers Auto* factories, and load model + processor.

        Also caches model geometry (layer/head counts, special token ids)
        used by the decoding loop.
        """
        from whisperlivekit.qwen3_asr import QWEN3_MODEL_MAPPING, _patch_transformers_compat
        _patch_transformers_compat()

        from qwen_asr.core.transformers_backend import (
            Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor,
        )
        from transformers import AutoConfig, AutoModel, AutoProcessor

        # Register the custom "qwen3_asr" architecture before from_pretrained
        # so the Auto* factories can resolve it.
        AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
        AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
        AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)

        # Model id priority: explicit dir > explicit path > mapped size > default.
        if model_dir:
            model_id = model_dir
        elif model_path:
            model_id = model_path
        elif model_size:
            model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
        else:
            model_id = "Qwen/Qwen3-ASR-1.7B"

        if torch.cuda.is_available():
            dtype, device = torch.bfloat16, "cuda:0"
        else:
            dtype, device = torch.float32, "cpu"

        logger.info("Loading Qwen3-ASR for SimulStreaming+KV: %s", model_id)
        self.model = AutoModel.from_pretrained(model_id, dtype=dtype, device_map=device)
        self.model.eval()
        self.processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True)

        # Cache frequently-used model geometry for the decode loop.
        thinker = self.model.thinker
        text_config = thinker.config.text_config
        self.num_layers = text_config.num_hidden_layers
        self.num_heads = text_config.num_attention_heads
        self.num_kv_heads = text_config.num_key_value_heads
        self.audio_token_id = thinker.config.audio_token_id
        self.device = next(self.model.parameters()).device
        self.dtype = next(self.model.parameters()).dtype
        self.asr_text_token_id = self.processor.tokenizer.convert_tokens_to_ids("<asr_text>")

        # EOS tokens: hardcoded Qwen chat end ids plus the tokenizer's own.
        self.eos_ids = {151645, 151643}
        if self.processor.tokenizer.eos_token_id is not None:
            self.eos_ids.add(self.processor.tokenizer.eos_token_id)

        logger.info(
            "Qwen3-ASR loaded: %d layers x %d heads, device=%s",
            self.num_layers, self.num_heads, self.device,
        )

    def _load_alignment_heads(self, path):
        """Return up to cfg.max_alignment_heads (layer, head) pairs.

        Reads the precomputed "alignment_heads_compact" list from *path* when
        the file exists; otherwise falls back to every head in the last
        quarter of decoder layers (truncated to the same maximum).
        """
        max_heads = self.cfg.max_alignment_heads
        if path and Path(path).exists():
            with open(path) as f:
                data = json.load(f)
            all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
            heads = all_heads[:max_heads]
            logger.info("Loaded top %d alignment heads from %s", len(heads), path)
            return heads
        # Fallback: all heads in the last quarter of layers.
        default_heads = []
        start_layer = self.num_layers * 3 // 4
        for layer in range(start_layer, self.num_layers):
            for head in range(self.num_heads):
                default_heads.append((layer, head))
        logger.warning("No alignment heads file. Using %d default heads.", len(default_heads))
        return default_heads[:max_heads]

    def _warmup(self, audio):
        """Run one short generation to trigger lazy model/device initialization.

        Best-effort: failures are logged and never raised.
        """
        try:
            audio = audio[:SAMPLE_RATE * 2]  # two seconds is enough to warm up
            msgs = [{"role": "system", "content": ""}, {"role": "user", "content": [{"type": "audio", "audio": ""}]}]
            text_prompt = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
            inputs = self.processor(text=[text_prompt], audio=[audio], return_tensors="pt", padding=True)
            inputs = inputs.to(self.device).to(self.dtype)
            with torch.inference_mode():
                self.model.thinker.generate(**inputs, max_new_tokens=5, do_sample=False)
            logger.info("Warmup complete")
        except Exception as e:
            logger.warning("Warmup failed: %s", e)

    def transcribe(self, audio):
        # Batch transcription is not used by this backend; inference is
        # driven by the online processor instead.
        pass
|
||||
|
||||
|
||||
class Qwen3SimulKVOnlineProcessor:
    """
    Per-session online processor with KV cache reuse.

    Key optimization: instead of calling generate() each time (which does
    full prefill), we maintain a DynamicCache and do incremental prefill
    + manual greedy decoding with alignment head hooks.
    """

    # Expected input sample rate (Hz) for all audio buffers.
    SAMPLING_RATE = 16000
    # Silences at least this long (seconds) reset the session state instead
    # of being stitched into the audio buffer as zero samples.
    MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
def __init__(self, asr: Qwen3SimulKVASR, logfile=sys.stderr):
|
||||
self.asr = asr
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.buffer: List[ASRToken] = []
|
||||
self.state = Qwen3SimulKVState()
|
||||
self._build_prompt_template()
|
||||
|
||||
def _build_prompt_template(self):
|
||||
from whisperlivekit.qwen3_asr import WHISPER_TO_QWEN3_LANGUAGE
|
||||
msgs = [
|
||||
{"role": "system", "content": ""},
|
||||
{"role": "user", "content": [{"type": "audio", "audio": ""}]},
|
||||
]
|
||||
self._base_prompt = self.asr.processor.apply_chat_template(
|
||||
msgs, add_generation_prompt=True, tokenize=False,
|
||||
)
|
||||
lan = self.asr.cfg.language
|
||||
if lan and lan != "auto":
|
||||
lang_name = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
|
||||
self._base_prompt += f"language {lang_name}<asr_text>"
|
||||
|
||||
    @property
    def speaker(self):
        """Current speaker id; delegates to the session state."""
        return self.state.speaker

    @speaker.setter
    def speaker(self, value):
        self.state.speaker = value
|
||||
|
||||
    @property
    def global_time_offset(self):
        """Offset added to emitted token times; delegates to the session state."""
        return self.state.global_time_offset

    @global_time_offset.setter
    def global_time_offset(self, value):
        self.state.global_time_offset = value
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||
self.end = audio_stream_end_time
|
||||
self.state.audio_buffer = np.append(self.state.audio_buffer, audio)
|
||||
|
||||
max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
|
||||
if len(self.state.audio_buffer) > max_samples:
|
||||
trim = len(self.state.audio_buffer) - max_samples
|
||||
self.state.audio_buffer = self.state.audio_buffer[trim:]
|
||||
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
|
||||
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
|
||||
self.state.audio_cache.reset()
|
||||
self.state.reset_kv() # Must invalidate KV when audio is trimmed
|
||||
|
||||
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||
all_tokens = []
|
||||
for _ in range(5):
|
||||
tokens, _ = self.process_iter(is_last=True)
|
||||
if not tokens:
|
||||
break
|
||||
all_tokens.extend(tokens)
|
||||
return all_tokens, self.end
|
||||
|
||||
def end_silence(self, silence_duration: float, offset: float):
|
||||
self.end += silence_duration
|
||||
long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
|
||||
if not long_silence:
|
||||
gap_len = int(self.SAMPLING_RATE * silence_duration)
|
||||
if gap_len > 0:
|
||||
self.state.audio_buffer = np.append(
|
||||
self.state.audio_buffer, np.zeros(gap_len, dtype=np.float32),
|
||||
)
|
||||
else:
|
||||
self.state = Qwen3SimulKVState()
|
||||
self.state.global_time_offset = silence_duration + offset
|
||||
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
self.process_iter(is_last=True)
|
||||
self.state = Qwen3SimulKVState()
|
||||
self.state.speaker = change_speaker.speaker
|
||||
self.state.global_time_offset = change_speaker.start
|
||||
|
||||
def get_buffer(self) -> Transcript:
|
||||
return Transcript.from_tokens(tokens=self.buffer, sep='')
|
||||
|
||||
    def _encode_audio(self) -> Tuple[torch.Tensor, int]:
        """Encode full audio buffer, with caching for stable windows.

        Returns:
            (audio_embeds, total_audio_tokens): encoder embeddings for the
            whole buffer (batch dim dropped) and the corresponding token count.
        """
        asr = self.asr
        state = self.state

        from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
            _get_feat_extract_output_lengths,
        )

        # Mel features for the whole buffer (batch of one).
        feat_out = asr.processor.feature_extractor(
            [state.audio_buffer], sampling_rate=16000,
            padding=True, truncation=False,
            return_attention_mask=True, return_tensors="pt",
        )
        input_features = feat_out["input_features"].to(asr.device).to(asr.dtype)
        feature_attention_mask = feat_out["attention_mask"].to(asr.device)
        total_mel_frames = feature_attention_mask.sum().item()
        total_audio_tokens = _get_feat_extract_output_lengths(
            torch.tensor(total_mel_frames),
        ).item()

        cache = state.audio_cache
        audio_cfg = asr.model.thinker.audio_tower.config
        # Encoder window size in mel frames. Embeddings for frames inside
        # complete windows no longer change as more audio arrives ("stable").
        n_window_infer = getattr(audio_cfg, "n_window_infer", 400)
        n_complete_windows = total_mel_frames // n_window_infer

        if n_complete_windows <= 0 or cache.embeddings is None:
            # Full encode: nothing cached yet, or no stable window exists.
            audio_embeds = asr.model.thinker.get_audio_features(
                input_features, feature_attention_mask=feature_attention_mask,
            )
            if audio_embeds.dim() == 3:
                audio_embeds = audio_embeds[0]  # drop batch dim
            stable_mel = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
            stable_tokens = _get_feat_extract_output_lengths(
                torch.tensor(stable_mel),
            ).item() if stable_mel > 0 else 0
        else:
            stable_mel = n_complete_windows * n_window_infer
            stable_tokens = _get_feat_extract_output_lengths(
                torch.tensor(stable_mel),
            ).item()

            if cache.stable_tokens > 0 and cache.stable_tokens <= stable_tokens:
                # Reuse the cached prefix for the stable windows; re-encode
                # only the trailing (still-changing) mel frames.
                cached_prefix = cache.embeddings[:stable_tokens] if cache.embeddings.dim() == 2 else cache.embeddings[0, :stable_tokens]
                tail_features = input_features[:, :, stable_mel:]
                tail_mel_frames = total_mel_frames - stable_mel
                if tail_mel_frames > 0:
                    tail_mask = torch.ones(
                        (1, tail_features.shape[2]),
                        dtype=feature_attention_mask.dtype,
                        device=feature_attention_mask.device,
                    )
                    tail_embeds = asr.model.thinker.get_audio_features(
                        tail_features, feature_attention_mask=tail_mask,
                    )
                    if tail_embeds.dim() == 3:
                        tail_embeds = tail_embeds[0]
                    audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0)
                else:
                    audio_embeds = cached_prefix
            else:
                # Cache unusable (e.g. stable region shrank after a trim):
                # fall back to a full re-encode.
                audio_embeds = asr.model.thinker.get_audio_features(
                    input_features, feature_attention_mask=feature_attention_mask,
                )
                if audio_embeds.dim() == 3:
                    audio_embeds = audio_embeds[0]

        # Update cache
        cache.embeddings = audio_embeds if audio_embeds.dim() == 2 else audio_embeds[0]
        cache.encoded_samples = len(state.audio_buffer)
        cache.encoded_mel_frames = total_mel_frames
        stable_mel_final = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
        cache.stable_tokens = _get_feat_extract_output_lengths(
            torch.tensor(stable_mel_final),
        ).item() if stable_mel_final > 0 else 0

        return audio_embeds, total_audio_tokens
|
||||
|
||||
    def _build_full_inputs(self, audio_embeds: torch.Tensor) -> Optional[dict]:
        """Build full input embeddings from prompt + audio embeddings + context.

        Returns a dict with input_ids / inputs_embeds / attention_mask plus
        the [audio_start, audio_end) placeholder token range, or None when
        the placeholder count does not match the audio embeddings.
        """
        asr = self.asr
        state = self.state
        thinker = asr.model.thinker

        # NOTE(review): imported but unused in this method.
        from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
            _get_feat_extract_output_lengths,
        )

        n_audio_tokens = audio_embeds.shape[0]

        # Expand the single audio placeholder in the prompt template into
        # n_audio_tokens placeholder tokens, then tokenize.
        prompt_with_placeholders = asr.processor.replace_multimodal_special_tokens(
            [self._base_prompt], iter([n_audio_tokens]),
        )[0]
        text_ids = asr.processor.tokenizer(
            [prompt_with_placeholders], return_tensors="pt", padding=True,
        )
        input_ids = text_ids["input_ids"].to(asr.device)
        attention_mask = text_ids.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(asr.device)

        # Append committed context tokens (bounded window) so decoding
        # continues coherently from previously emitted text.
        if state.committed_token_ids:
            ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
            ctx_ids = torch.tensor([ctx], dtype=input_ids.dtype, device=input_ids.device)
            input_ids = torch.cat([input_ids, ctx_ids], dim=1)
            if attention_mask is not None:
                ctx_mask = torch.ones_like(ctx_ids)
                attention_mask = torch.cat([attention_mask, ctx_mask], dim=1)

        # Build inputs_embeds: text embeddings with the audio embeddings
        # scattered into the placeholder positions.
        inputs_embeds = thinker.get_input_embeddings()(input_ids)
        audio_mask = (input_ids == asr.audio_token_id)
        n_placeholders = audio_mask.sum().item()

        if n_placeholders != n_audio_tokens:
            logger.warning("Audio token mismatch: %d vs %d", n_placeholders, n_audio_tokens)
            return None

        audio_embeds_cast = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        expand_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(expand_mask, audio_embeds_cast)

        # Find audio token range
        audio_positions = audio_mask[0].nonzero(as_tuple=True)[0]
        audio_start = audio_positions[0].item()
        audio_end = audio_positions[-1].item() + 1

        return {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "audio_start": audio_start,
            "audio_end": audio_end,
            "n_audio_tokens": n_audio_tokens,
        }
|
||||
|
||||
@torch.inference_mode()
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
|
||||
if audio_duration < self.asr.cfg.audio_min_len:
|
||||
return [], self.end
|
||||
|
||||
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
|
||||
min_new_seconds = 1.0
|
||||
if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
|
||||
return [], self.end
|
||||
|
||||
self.state.last_infer_samples = len(self.state.audio_buffer)
|
||||
|
||||
try:
|
||||
timestamped_words = self._infer(is_last)
|
||||
except Exception as e:
|
||||
logger.exception("Inference error: %s", e)
|
||||
self.state.reset_kv()
|
||||
return [], self.end
|
||||
|
||||
if not timestamped_words:
|
||||
return [], self.end
|
||||
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
|
||||
    def _infer(self, is_last: bool) -> List[ASRToken]:
        """Run inference with KV cache reuse and alignment-head stopping.

        Pipeline: encode audio (cached) -> build prompt+audio embeddings ->
        full prefill -> greedy decode with hooks that track which audio frame
        the alignment heads attend to, stopping early at the unstable border
        or on a backward attention jump -> strip metadata -> emit words.
        """
        asr = self.asr
        state = self.state
        thinker = asr.model.thinker

        # Step 1: Encode audio (with caching)
        audio_embeds, n_audio_tokens_total = self._encode_audio()

        # Step 2: Build full inputs
        full_inputs = self._build_full_inputs(audio_embeds)
        if full_inputs is None:
            state.reset_kv()
            return []

        input_ids = full_inputs["input_ids"]
        inputs_embeds = full_inputs["inputs_embeds"]
        attention_mask = full_inputs["attention_mask"]
        audio_start = full_inputs["audio_start"]
        audio_end = full_inputs["audio_end"]
        n_audio_tokens = full_inputs["n_audio_tokens"]
        audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE

        # Step 3: Full prefill (we always re-prefill since audio tokens change)
        # Future optimization: partial prefill when only tail audio changes
        out = thinker(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=True,
        )
        kv_cache = out.past_key_values
        prompt_len = input_ids.shape[1]  # (currently unused)

        # Step 4: Greedy decode with alignment head stopping.
        # border: stop once attention reaches the unstable tail of the audio;
        # rewind: stop when attention jumps far backwards (likely re-reading).
        border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
        rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
        last_attend_frame = state.last_attend_frame

        # Install hooks for alignment head attention extraction
        decoder_layers = thinker.model.layers
        num_kv_heads = asr.num_kv_heads
        num_heads = asr.num_heads
        gqa_ratio = num_heads // num_kv_heads  # query heads per KV head (GQA)

        from qwen_asr.core.transformers_backend.modeling_qwen3_asr import apply_rotary_pos_emb

        per_step_frames: List[List[int]] = []  # one list of frames per decode step
        current_step_frames: List[int] = []    # filled by hooks during one step
        hooks = []

        def _make_attn_hook(layer_idx):
            # Forward hook: for each alignment head in this layer, record the
            # audio position with the highest attention score for the current
            # single-token decode step.
            head_indices = asr.heads_by_layer[layer_idx]
            def hook_fn(module, args, kwargs, output):
                hidden_states = kwargs.get('hidden_states')
                if hidden_states is None:
                    hidden_states = args[0] if args else None
                # Only act on single-token (decode) steps, not prefill.
                if hidden_states is None or hidden_states.shape[1] != 1:
                    return
                position_embeddings = kwargs.get('position_embeddings')
                if position_embeddings is None and len(args) > 1:
                    position_embeddings = args[1]
                past_kv = kwargs.get('past_key_values')
                if position_embeddings is None or past_kv is None:
                    return

                # Recompute this step's query projection (norm + RoPE applied).
                hidden_shape = (*hidden_states.shape[:-1], -1, module.head_dim)
                q = module.q_norm(module.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
                cos, sin = position_embeddings
                q, _ = apply_rotary_pos_emb(q, q, cos, sin)

                cache_layer = past_kv.layers[module.layer_idx]
                k = cache_layer.keys
                if k is None or audio_end > k.shape[2]:
                    return

                for h_idx in head_indices:
                    if h_idx >= q.shape[1]:
                        continue
                    kv_h_idx = h_idx // gqa_ratio
                    q_h = q[0, h_idx, 0]
                    k_audio = k[0, kv_h_idx, audio_start:audio_end]
                    # Unnormalized attention scores over audio positions only.
                    scores = torch.matmul(k_audio, q_h)
                    frame = scores.argmax().item()
                    current_step_frames.append(frame)
            return hook_fn

        for layer_idx in asr.heads_by_layer:
            if layer_idx < len(decoder_layers):
                h = decoder_layers[layer_idx].self_attn.register_forward_hook(
                    _make_attn_hook(layer_idx), with_kwargs=True,
                )
                hooks.append(h)

        try:
            # Greedy decoding with alignment-based stopping
            next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated_ids = []
            border_stop_step = None
            tokens_per_sec = 6  # rough speech rate used to budget decode length
            if is_last:
                max_tokens = min(int(audio_duration * tokens_per_sec) + 10, 120)
            else:
                # Audio that arrived since last_infer_samples was recorded.
                new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
                max_tokens = min(int(max(new_audio_secs, 1.0) * tokens_per_sec) + 5, 40)

            for step in range(max_tokens):
                tid = next_token.item()
                if tid in asr.eos_ids:
                    break
                generated_ids.append(tid)

                # Collect alignment frames for this step
                if current_step_frames:
                    per_step_frames.append(current_step_frames)
                    current_step_frames = []

                # Check stopping criteria (after 3 tokens)
                if not is_last and len(per_step_frames) >= 3:
                    latest = per_step_frames[-1]
                    if latest:
                        # Median attended frame across the alignment heads.
                        frames_sorted = sorted(latest)
                        attended = frames_sorted[len(frames_sorted) // 2]

                        # Rewind: large backward jump — drop the latest steps.
                        if last_attend_frame - attended > rewind_threshold:
                            border_stop_step = max(0, len(per_step_frames) - 2)
                            break

                        last_attend_frame = attended

                        # Border: attention reached the unstable audio tail.
                        if (n_audio_tokens - attended) <= border_threshold:
                            border_stop_step = len(per_step_frames) - 1
                            break

                # Next token
                out = thinker(
                    input_ids=next_token,
                    past_key_values=kv_cache,
                    use_cache=True,
                )
                kv_cache = out.past_key_values
                next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

            # Flush remaining frames
            if current_step_frames:
                per_step_frames.append(current_step_frames)
        finally:
            for h in hooks:
                h.remove()

        state.last_attend_frame = last_attend_frame

        if not generated_ids:
            return []

        # Strip metadata prefix (<asr_text> token)
        all_generated = torch.tensor(generated_ids, device=asr.device)  # (currently unused)
        num_gen = len(generated_ids)
        asr_text_id = asr.asr_text_token_id
        metadata_offset = 0
        for i in range(min(num_gen, 10)):
            if generated_ids[i] == asr_text_id:
                # Tokens before <asr_text> carry metadata, e.g. "language XX":
                # use them once to populate the detected language.
                if state.detected_language is None and i > 0:
                    from whisperlivekit.qwen3_asr import QWEN3_TO_WHISPER_LANGUAGE
                    prefix_text = asr.processor.tokenizer.decode(
                        generated_ids[:i], skip_special_tokens=True,
                    ).strip()
                    parts = prefix_text.split()
                    if len(parts) >= 2:
                        lang_name = parts[-1]
                        if lang_name.lower() != "none":
                            state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
                                lang_name, lang_name.lower(),
                            )
                metadata_offset = i + 1
                break

        if metadata_offset > 0:
            generated_ids = generated_ids[metadata_offset:]
            num_gen -= metadata_offset
            per_step_frames = per_step_frames[metadata_offset:]

        if num_gen <= 0:
            return []

        # Determine emit count
        if border_stop_step is not None:
            emit_up_to = min(border_stop_step, num_gen)
        else:
            emit_up_to = num_gen

        emitted_ids = generated_ids[:emit_up_to]
        if not emitted_ids:
            return []

        # Build timestamped words
        words = self._build_timestamped_words(
            emitted_ids, per_step_frames, emit_up_to,
            n_audio_tokens, audio_duration,
        )

        state.committed_word_count += len(words)
        # Include metadata in committed tokens for context
        all_emitted = generated_ids[:emit_up_to]
        if metadata_offset > 0:
            all_emitted = generated_ids[:emit_up_to]  # already stripped (no-op branch)
        state.committed_token_ids.extend(all_emitted)

        return words
|
||||
|
||||
def _build_timestamped_words(
|
||||
self,
|
||||
generated_ids: list,
|
||||
step_frames: List[List[int]],
|
||||
emit_up_to: int,
|
||||
n_audio_tokens: int,
|
||||
audio_duration: float,
|
||||
) -> List[ASRToken]:
|
||||
asr = self.asr
|
||||
state = self.state
|
||||
|
||||
per_token_frame = []
|
||||
for step in range(emit_up_to):
|
||||
if step < len(step_frames) and step_frames[step]:
|
||||
frames = sorted(step_frames[step])
|
||||
per_token_frame.append(frames[len(frames) // 2])
|
||||
else:
|
||||
per_token_frame.append(None)
|
||||
|
||||
tokenizer = asr.processor.tokenizer
|
||||
full_text = tokenizer.decode(generated_ids[:emit_up_to], skip_special_tokens=True)
|
||||
text_words = full_text.split()
|
||||
|
||||
all_frames = [f for f in per_token_frame if f is not None]
|
||||
words = []
|
||||
for wi, word in enumerate(text_words):
|
||||
if all_frames:
|
||||
frac = wi / max(len(text_words), 1)
|
||||
frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
|
||||
frame = all_frames[frame_idx]
|
||||
else:
|
||||
frame = None
|
||||
words.append((word, frame))
|
||||
|
||||
tokens = []
|
||||
for i, (text, frame) in enumerate(words):
|
||||
text = text.strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
if frame is not None and n_audio_tokens > 0:
|
||||
timestamp = (
|
||||
frame / n_audio_tokens * audio_duration
|
||||
+ state.cumulative_time_offset
|
||||
)
|
||||
else:
|
||||
timestamp = (
|
||||
(i / max(len(words), 1)) * audio_duration
|
||||
+ state.cumulative_time_offset
|
||||
)
|
||||
|
||||
is_very_first_word = (i == 0 and state.committed_word_count == 0)
|
||||
display_text = text if is_very_first_word else " " + text
|
||||
|
||||
token = ASRToken(
|
||||
start=round(timestamp, 2),
|
||||
end=round(timestamp + 0.1, 2),
|
||||
text=display_text,
|
||||
speaker=state.speaker,
|
||||
detected_language=state.detected_language,
|
||||
).with_offset(state.global_time_offset)
|
||||
tokens.append(token)
|
||||
|
||||
return tokens
|
||||
|
||||
def warmup(self, audio: np.ndarray, init_prompt: str = ""):
|
||||
try:
|
||||
self.state.audio_buffer = audio[:SAMPLE_RATE]
|
||||
self.process_iter(is_last=True)
|
||||
self.state = Qwen3SimulKVState()
|
||||
except Exception as e:
|
||||
logger.warning("Warmup failed: %s", e)
|
||||
self.state = Qwen3SimulKVState()
|
||||
|
||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||
all_tokens = []
|
||||
for _ in range(5):
|
||||
tokens, _ = self.process_iter(is_last=True)
|
||||
if not tokens:
|
||||
break
|
||||
all_tokens.extend(tokens)
|
||||
return all_tokens, self.end
|
||||