Files
WhisperLiveKit/whisperlivekit/qwen3_mlx_simul.py
Quentin Fuxa 3f233dc36c Fix all ruff lint errors (68 errors → 0)
- Remove unused imports and variables (F401, F841)
- Sort import blocks (I001)
- Split semicolon-separated statements (E702)
- Fix backslash in f-string for Python 3.11 compat (cli.py)
- Remove empty f-strings (F541)
- Add noqa for intentional E402 after sys.path manipulation
2026-03-31 23:02:50 +02:00

760 lines
28 KiB
Python

"""
Qwen3-ASR SimulStreaming (AlignAtt) on MLX for Apple Silicon.
Uses the ``mlx_qwen3_asr`` library for model loading, audio encoding, and
tokenization. Implements the AlignAtt border-distance policy by monkey-
patching ``TextAttention.__call__`` on alignment layers to capture Q (with
RoPE) during autoregressive decode steps, then computing ``Q @ K_audio^T``
from the KV cache to find the most-attended audio frame.
This is the MLX equivalent of ``qwen3_simul.py`` (PyTorch) which uses
``register_forward_hook`` for the same purpose.
"""
import json
import logging
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)

SAMPLE_RATE = 16_000

# Model identifiers handed to ``mlx_qwen3_asr.load_model``.
_QWEN3_ASR_0_6B = "Qwen/Qwen3-ASR-0.6B"
_QWEN3_ASR_1_7B = "Qwen/Qwen3-ASR-1.7B"

# Model size aliases (same as qwen3_mlx_asr.py). Whisper-style size names
# (tiny/base/small/medium/large) map onto one of the two Qwen3-ASR checkpoints.
QWEN3_MLX_MODEL_MAPPING = {
    "base": _QWEN3_ASR_0_6B,
    "tiny": _QWEN3_ASR_0_6B,
    "small": _QWEN3_ASR_0_6B,
    "large": _QWEN3_ASR_1_7B,
    "medium": _QWEN3_ASR_1_7B,
    "large-v3": _QWEN3_ASR_1_7B,
    "qwen3-asr-1.7b": _QWEN3_ASR_1_7B,
    "qwen3-asr-0.6b": _QWEN3_ASR_0_6B,
    "qwen3-1.7b": _QWEN3_ASR_1_7B,
    "qwen3-0.6b": _QWEN3_ASR_0_6B,
    "1.7b": _QWEN3_ASR_1_7B,
    "0.6b": _QWEN3_ASR_0_6B,
}

# Whisper language codes -> Qwen3 canonical language names.
WHISPER_TO_QWEN3_LANGUAGE = {
    "zh": "Chinese",
    "en": "English",
    "yue": "Cantonese",
    "ar": "Arabic",
    "de": "German",
    "fr": "French",
    "es": "Spanish",
    "pt": "Portuguese",
    "id": "Indonesian",
    "it": "Italian",
    "ko": "Korean",
    "ru": "Russian",
    "th": "Thai",
    "vi": "Vietnamese",
    "ja": "Japanese",
    "tr": "Turkish",
    "hi": "Hindi",
    "ms": "Malay",
    "nl": "Dutch",
    "sv": "Swedish",
    "da": "Danish",
    "fi": "Finnish",
    "pl": "Polish",
    "cs": "Czech",
    "fa": "Persian",
    "el": "Greek",
    "hu": "Hungarian",
    "mk": "Macedonian",
    "ro": "Romanian",
}

# Reverse lookup (Qwen3 language name -> Whisper code); the forward map has
# unique values, so this is a true inverse.
QWEN3_TO_WHISPER_LANGUAGE = dict(
    zip(WHISPER_TO_QWEN3_LANGUAGE.values(), WHISPER_TO_QWEN3_LANGUAGE.keys())
)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class Qwen3MLXSimulConfig:
    """Tunable settings for Qwen3-ASR SimulStreaming (AlignAtt) on MLX.

    Field order is part of the positional constructor signature; keep it.
    """

    # Whisper-style language code; "auto" leaves the prompt language unset.
    language: str = "auto"
    # JSON file with an "alignment_heads_compact" list of (layer, head) pairs;
    # None selects the built-in heuristic.
    alignment_heads_path: Optional[str] = None
    # Decoding stops once the attended frame is within this fraction of the
    # audio window's end.
    border_fraction: float = 0.15
    # Decoding also stops when the attended frame jumps back by more than this
    # fraction of the audio window.
    rewind_fraction: float = 0.12
    # Minimum buffered audio (seconds) used to gate inference in the online
    # processor.
    audio_min_len: float = 3.0
    # The audio buffer is trimmed from the front beyond this many seconds.
    audio_max_len: float = 15.0
    # Number of most recent committed tokens fed back into the prompt.
    max_context_tokens: int = 30
    # Upper bound on the number of alignment heads that are monitored.
    max_alignment_heads: int = 20
# ---------------------------------------------------------------------------
# Per-session state
# ---------------------------------------------------------------------------
def _empty_audio_buffer() -> np.ndarray:
    """Return a new, empty float32 sample buffer."""
    return np.array([], dtype=np.float32)


@dataclass
class _SessionState:
    """Mutable per-session decoding state for the online processor.

    Every field has a default, so ``_SessionState()`` is a clean slate; the
    online processor resets a session by building a fresh instance.
    """

    # Buffered audio samples that each inference pass encodes in full.
    audio_buffer: np.ndarray = field(default_factory=_empty_audio_buffer)
    # Seconds of audio trimmed off the front of the buffer so far; added to
    # emitted timestamps.
    cumulative_time_offset: float = 0.0
    # Offset (seconds) applied to every emitted token via ``with_offset``.
    global_time_offset: float = 0.0
    # Speaker label attached to emitted tokens.
    speaker: int = -1
    # Median attended audio frame from the last border check. A negative start
    # value keeps the first rewind check from firing (attended frames are >= 0).
    last_attend_frame: int = -15
    # Words emitted so far; 0 means the next word is the first of the session.
    committed_word_count: int = 0
    # Token ids emitted so far; their tail is fed back as prompt context.
    committed_token_ids: List[int] = field(default_factory=list)
    # Whisper-style language code once detected, else None.
    detected_language: Optional[str] = None
    # Buffer length (samples) at the last inference pass.
    last_infer_samples: int = 0
    # Pending partial word from a previous inference call, intended to hold a
    # word cut by a border stop (e.g. "Vill" from "Villard") so it can be
    # prepended to the next call's output.
    # NOTE(review): not read by the online processor in this file.
    pending_partial: str = ""
    pending_partial_start: Optional[float] = None
    # Whether the first emitted token of a call continues the previous call's
    # last word (no leading space -> subword continuation).
    # NOTE(review): not read by the online processor in this file.
    first_emit_is_continuation: bool = False
# ---------------------------------------------------------------------------
# Shared model holder
# ---------------------------------------------------------------------------
class Qwen3MLXSimulStreamingASR:
    """Loads the Qwen3-ASR model via ``mlx_qwen3_asr`` once and keeps it
    alive for the lifetime of the server. Shared across sessions.

    Holds the model, tokenizer, decoder architecture numbers and the selected
    alignment heads (``alignment_heads`` / ``heads_by_layer``) used by the
    per-session ``Qwen3MLXSimulStreamingOnlineProcessor``.
    """

    sep = ""
    SAMPLING_RATE = SAMPLE_RATE

    def __init__(
        self,
        model_size: Optional[str] = None,
        model_dir: Optional[str] = None,
        model_path: Optional[str] = None,
        lan: str = "auto",
        alignment_heads_path: Optional[str] = None,
        border_fraction: float = 0.15,
        warmup_file: Optional[str] = None,
        model_cache_dir: Optional[str] = None,
        lora_path: Optional[str] = None,
        min_chunk_size: float = 0.1,
        direct_english_translation: bool = False,
        **kwargs,
    ):
        """Load the model, tokenizer and alignment heads.

        Args:
            model_size: Size alias (see ``QWEN3_MLX_MODEL_MAPPING``), or a repo
                id / relative path (anything containing "/" or starting with
                "."), which is used verbatim. Defaults to "base".
            model_dir / model_path: Explicit model location; takes precedence
                over ``model_size``.
            lan: Whisper-style language code or "auto".
            alignment_heads_path: Optional JSON file of alignment heads.
            border_fraction: AlignAtt border fraction (see the config).
            warmup_file: Optional audio file used for a best-effort warmup.
            model_cache_dir, lora_path, min_chunk_size,
            direct_english_translation, **kwargs: accepted for interface
                compatibility; not used by this backend.
        """
        import mlx.core as mx
        import mlx_qwen3_asr

        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        self.warmup_file = warmup_file
        self.cfg = Qwen3MLXSimulConfig(
            language=lan,
            alignment_heads_path=alignment_heads_path,
            border_fraction=border_fraction,
        )

        # Resolve model path: explicit location first, then model_size.
        resolved = model_dir or model_path
        if not resolved:
            requested = model_size or "base"
            if "/" in requested or requested.startswith("."):
                # Repo ids / paths are used verbatim; lowercasing them could
                # break case-sensitive local paths.
                resolved = requested
            else:
                resolved = QWEN3_MLX_MODEL_MAPPING.get(requested.lower())
                if resolved is None:
                    resolved = "Qwen/Qwen3-ASR-0.6B"
                    logger.warning(
                        "Unknown Qwen3-ASR model size %r; defaulting to %s",
                        requested, resolved,
                    )

        t0 = time.perf_counter()
        logger.info("Loading Qwen3-ASR MLX model '%s' for SimulStreaming ...", resolved)
        self.model, self._config = mlx_qwen3_asr.load_model(resolved, dtype=mx.float16)
        logger.info("Model loaded in %.2fs", time.perf_counter() - t0)

        # Tokenizer: prefer the directory the loader actually resolved.
        tok_path = getattr(self.model, "_resolved_model_path", None) or resolved
        self.tokenizer = mlx_qwen3_asr.tokenizer.Tokenizer(str(tok_path))

        # Architecture info
        text_cfg = self._config.text_config
        self.num_layers = text_cfg.num_hidden_layers
        self.num_heads = text_cfg.num_attention_heads
        self.num_kv_heads = text_cfg.num_key_value_heads
        self.head_dim = text_cfg.head_dim
        # Query heads sharing one KV head (grouped-query attention).
        self.gqa_ratio = self.num_heads // self.num_kv_heads
        self.audio_token_id = self._config.audio_token_id
        logger.info(
            "Qwen3-ASR arch: %d layers x %d heads (%d kv), head_dim=%d, GQA=%d",
            self.num_layers, self.num_heads, self.num_kv_heads,
            self.head_dim, self.gqa_ratio,
        )

        # Alignment heads, grouped by layer for hook installation.
        self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
        self.heads_by_layer = {}
        for layer_idx, head_idx in self.alignment_heads:
            self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)

        self.backend_choice = "qwen3-mlx-simul"

        # Warmup
        if warmup_file:
            from whisperlivekit.warmup import load_file
            audio = load_file(warmup_file)
            if audio is not None:
                self._warmup(audio)

    def _load_alignment_heads(
        self, path: Optional[str],
    ) -> List[Tuple[int, int]]:
        """Return up to ``cfg.max_alignment_heads`` ``(layer, head)`` pairs.

        When *path* exists, heads are read from its ``"alignment_heads_compact"``
        list (file order kept) and pairs outside this model's layer/head range
        are dropped, e.g. when the file was produced for another model size.
        Otherwise, or if no usable pair remains, every head of the last quarter
        of the decoder layers is taken in (layer, head) order.
        """
        max_heads = self.cfg.max_alignment_heads
        if path and Path(path).exists():
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
            all_heads = [(int(h[0]), int(h[1])) for h in data["alignment_heads_compact"]]
            usable = [
                (layer, head)
                for layer, head in all_heads
                if 0 <= layer < self.num_layers and 0 <= head < self.num_heads
            ]
            if len(usable) < len(all_heads):
                logger.warning(
                    "Ignoring %d alignment heads outside %d layers x %d heads "
                    "(heads file may belong to a different model size).",
                    len(all_heads) - len(usable), self.num_layers, self.num_heads,
                )
            if usable:
                heads = usable[:max_heads]
                logger.info(
                    "Loaded top %d alignment heads from %s (of %d total)",
                    len(heads), path, len(all_heads),
                )
                return heads
            logger.warning("No usable alignment heads in %s.", path)

        # Default heuristic: last quarter of layers, all heads (truncated).
        start_layer = self.num_layers * 3 // 4
        default_heads = [
            (layer, head)
            for layer in range(start_layer, self.num_layers)
            for head in range(self.num_heads)
        ][:max_heads]
        last_layer = default_heads[-1][0] if default_heads else self.num_layers - 1
        logger.warning(
            "No alignment heads file. Using default heuristic: "
            "%d heads from layers %d-%d.",
            len(default_heads), start_layer, last_layer,
        )
        return default_heads

    def _warmup(self, audio: np.ndarray):
        """Best-effort warmup: encode up to 2 s of *audio* and run one prefill.

        Any failure is logged as a warning and otherwise ignored.
        """
        import mlx.core as mx
        try:
            from mlx_qwen3_asr.audio import compute_features
            audio = audio[:SAMPLE_RATE * 2]
            mel, feat_lens = compute_features(audio)
            mel = mel.astype(mx.float16)
            audio_features, _ = self.model.audio_tower(mel, feat_lens)
            n_audio = int(audio_features.shape[1])
            prompt = self.tokenizer.build_prompt_tokens(n_audio, language="English")
            input_ids = mx.array([prompt])
            positions = mx.arange(input_ids.shape[1])[None, :]
            position_ids = mx.stack([positions, positions, positions], axis=1)
            cache = self.model.create_cache()
            logits = self.model.prefill(input_ids, audio_features, position_ids, cache)
            mx.eval(logits)
            logger.info("Qwen3 MLX SimulStreaming warmup complete")
        except Exception as e:
            logger.warning("Warmup failed: %s", e)

    def transcribe(self, audio):
        """No-op: all decoding happens in the per-session online processor."""
        pass
# ---------------------------------------------------------------------------
# Attention capture via wrapper replacement
# ---------------------------------------------------------------------------
class _AttnCaptureWrapper:
    """Wraps a TextAttention module to capture alignment frames during decode.

    Replaces ``layer.self_attn`` with this wrapper. On decode steps (L=1),
    recomputes Q with RoPE, reads cached K from the audio region, computes
    ``Q @ K_audio^T`` for each selected head, and appends each head's argmax
    frame to ``capture["step_frames"]``. The wrapped module is always called
    afterwards with the original arguments, and its result is returned.

    Python dunder resolution (``__call__``) goes through the *class*, not the
    instance, so monkey-patching ``attn.__call__`` on an ``nn.Module`` does
    not work. This wrapper class defines its own ``__call__`` and delegates
    everything else to the wrapped module via ``__getattr__``.
    """

    def __init__(self, original, layer_idx, head_indices, gqa_ratio,
                 audio_start, audio_end, capture):
        # Keep wrapper state in the instance __dict__ and read it back through
        # __dict__ too, so lookups never fall through to __getattr__ (which
        # forwards to the wrapped module).
        self.__dict__["_original"] = original
        self.__dict__["_layer_idx"] = layer_idx
        self.__dict__["_head_indices"] = head_indices
        self.__dict__["_gqa_ratio"] = gqa_ratio
        self.__dict__["_audio_start"] = audio_start
        self.__dict__["_audio_end"] = audio_end
        self.__dict__["_capture"] = capture

    def __call__(self, x, cos, sin, mask=None, cache=None, layer_idx=0):
        import mlx.core as mx
        from mlx_qwen3_asr.mrope import apply_rotary_pos_emb

        slots = self.__dict__
        orig = slots["_original"]
        B, L, _ = x.shape
        if L == 1 and cache is not None:
            head_indices = slots["_head_indices"]
            gqa = slots["_gqa_ratio"]
            a_start = slots["_audio_start"]
            a_end = slots["_audio_end"]
            # K from cache (already has RoPE baked in from cache.update).
            k_cached = cache.keys[slots["_layer_idx"]]
            if head_indices and k_cached is not None and a_end <= k_cached.shape[2]:
                # Recompute Q with RoPE (cheap: single token) using the
                # wrapped module's own q_proj / q_norm.
                q = orig.q_proj(x)
                q = q.reshape(B, L, orig.num_heads, orig.head_dim)
                q = orig.q_norm(q)
                q = q.transpose(0, 2, 1, 3)  # (B, H, 1, D)
                q_rope, _ = apply_rotary_pos_emb(q, q, cos, sin)
                # One lazy argmax per head (query head h reads KV head
                # h // gqa), materialised with a single host sync instead of
                # one .item() round-trip per head.
                best = [
                    mx.argmax(k_cached[0, h // gqa, a_start:a_end] @ q_rope[0, h, 0])
                    for h in head_indices
                ]
                slots["_capture"]["step_frames"].extend(
                    int(frame) for frame in mx.stack(best).tolist()
                )
        return orig(x, cos, sin, mask=mask, cache=cache, layer_idx=layer_idx)

    def __getattr__(self, name):
        # Only reached when normal lookup fails: forward to the wrapped module.
        return getattr(self.__dict__["_original"], name)
def _install_alignment_hooks(model, heads_by_layer, gqa_ratio, audio_start, audio_end, capture):
    """Replace ``self_attn`` on alignment layers with capture wrappers.

    Args:
        model: Model whose decoder layers live in ``model.model.layers``.
        heads_by_layer: Mapping ``layer_idx -> [head_idx, ...]``; layers
            outside ``model.model.layers`` are skipped.
        gqa_ratio: Query heads per KV head.
        audio_start, audio_end: Audio-token span ``[start, end)`` in the prompt.
        capture: Shared dict whose ``"step_frames"`` list receives frames.

    Returns a list of ``(layer_idx, original_attn)`` for later restoration.
    """
    layers = model.model.layers
    originals = []
    for layer_idx, head_indices in heads_by_layer.items():
        if not 0 <= layer_idx < len(layers):
            continue
        layer = layers[layer_idx]
        orig_attn = layer.self_attn
        # A wrapper left behind by an earlier pass (failed restore or an
        # exception) must not be wrapped again: nesting repeats the capture
        # work per level and deepens the call stack on every pass.
        while isinstance(orig_attn, _AttnCaptureWrapper):
            orig_attn = orig_attn.__dict__["_original"]
        layer.self_attn = _AttnCaptureWrapper(
            orig_attn, layer_idx, head_indices, gqa_ratio,
            audio_start, audio_end, capture,
        )
        originals.append((layer_idx, orig_attn))
    return originals
def _remove_alignment_hooks(model, originals):
    """Restore original self_attn modules.

    *originals* is the ``(layer_idx, original_attn)`` list returned by the
    installer; each listed layer gets its original module back.
    """
    for layer_idx, orig_attn in originals:
        layer = model.model.layers[layer_idx]
        # NOTE(review): mlx.nn.Module appears to keep child modules in its dict
        # storage but store other values (such as the capture wrapper) as plain
        # instance attributes, which win on attribute lookup. Dropping any
        # instance-level ``self_attn`` first makes the restored module the one
        # that ``layer.self_attn`` actually returns.
        instance_attrs = getattr(layer, "__dict__", None)
        if instance_attrs is not None:
            instance_attrs.pop("self_attn", None)
        layer.self_attn = orig_attn
# ---------------------------------------------------------------------------
# Per-session online processor
# ---------------------------------------------------------------------------
class Qwen3MLXSimulStreamingOnlineProcessor:
    """Per-session processor implementing AlignAtt on MLX.

    Same interface as other online processors:
    insert_audio_chunk / process_iter / get_buffer / start_silence /
    end_silence / finish / warmup / new_speaker.

    Per-session data lives in ``self.state`` (a ``_SessionState``); the model,
    tokenizer and alignment heads come from the shared ``asr`` object.
    """

    SAMPLING_RATE = SAMPLE_RATE
    # Silences at least this long (seconds) reset the session state instead of
    # being appended to the buffer as zeros.
    MIN_DURATION_REAL_SILENCE = 5
    # Upper bound on inference passes per final flush (start_silence/finish).
    _MAX_FLUSH_PASSES = 5

    def __init__(self, asr: Qwen3MLXSimulStreamingASR, logfile=sys.stderr):
        self.asr = asr
        self.logfile = logfile
        # Stream end time (seconds) returned alongside every result.
        self.end = 0.0
        self.buffer: List[ASRToken] = []
        self.state = _SessionState()

    # -- properties expected by AudioProcessor --

    @property
    def speaker(self):
        return self.state.speaker

    @speaker.setter
    def speaker(self, value):
        self.state.speaker = value

    @property
    def global_time_offset(self):
        return self.state.global_time_offset

    @global_time_offset.setter
    def global_time_offset(self, value):
        self.state.global_time_offset = value

    # -- audio ingestion --

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Append *audio* and trim the buffer to ``cfg.audio_max_len`` seconds.

        Samples trimmed from the front advance ``cumulative_time_offset`` (so
        later timestamps account for them) and shift the fresh-audio marker.
        """
        state = self.state
        self.end = audio_stream_end_time
        state.audio_buffer = np.append(state.audio_buffer, audio)

        max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
        if len(state.audio_buffer) > max_samples:
            trim = len(state.audio_buffer) - max_samples
            state.audio_buffer = state.audio_buffer[trim:]
            state.cumulative_time_offset += trim / self.SAMPLING_RATE
            state.last_infer_samples = max(0, state.last_infer_samples - trim)
            # NOTE(review): state.last_attend_frame is not shifted here, so the
            # next rewind check compares frame indices from before and after
            # the trim — confirm whether it should move with the buffer.

    # -- main processing --

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Run one inference pass if the buffer warrants it.

        Streaming calls (``is_last=False``) need at least ``cfg.audio_min_len``
        seconds buffered and at least 1 s of audio added since the last pass.
        Final calls (``is_last=True``) only need a non-empty buffer, so a short
        trailing utterance is still transcribed when the stream is flushed.

        Returns ``(new_tokens, self.end)``; inference errors are logged and
        produce no tokens.
        """
        n_samples = len(self.state.audio_buffer)
        if n_samples == 0:
            return [], self.end
        if not is_last:
            if n_samples / self.SAMPLING_RATE < self.asr.cfg.audio_min_len:
                return [], self.end
            # Throttle: at least 1s of new audio
            new_samples = n_samples - self.state.last_infer_samples
            if new_samples < int(1.0 * self.SAMPLING_RATE):
                return [], self.end

        try:
            words = self._infer(is_last)
        except Exception as e:
            logger.exception("Qwen3 MLX SimulStreaming inference error: %s", e)
            return [], self.end

        # Update the budget marker after _infer() so the decoder can size its
        # generation budget using the real amount of fresh audio.
        self.state.last_infer_samples = len(self.state.audio_buffer)
        if not words:
            return [], self.end
        self.buffer = []
        return words, self.end

    def _infer(self, is_last: bool) -> List[ASRToken]:
        """Run one inference cycle with alignment-head-based stopping.

        Encodes the whole buffer, prefills a prompt made of the audio tokens
        plus the tail of the committed tokens, then decodes greedily while the
        capture wrappers report the audio frame each alignment head attends.
        On streaming calls decoding stops when the attended frame comes within
        ``border_fraction`` of the buffer end or jumps back by more than
        ``rewind_fraction``; only tokens before that point are emitted and
        appended to the committed context.
        """
        import mlx.core as mx
        from mlx_qwen3_asr.audio import compute_features
        from mlx_qwen3_asr.generate import _detect_repetition
        from mlx_qwen3_asr.tokenizer import parse_asr_output

        asr = self.asr
        cfg = asr.cfg
        state = self.state
        model = asr.model

        # 1. Encode audio
        mel, feat_lens = compute_features(state.audio_buffer)
        mel = mel.astype(mx.float16)
        audio_features, _ = model.audio_tower(mel, feat_lens)
        n_audio_tokens = int(audio_features.shape[1])
        mx.eval(audio_features)
        if n_audio_tokens == 0:
            return []
        audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE

        # 2. Build prompt tokens
        lan = cfg.language
        language = None
        if lan and lan != "auto":
            language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
        prompt_tokens = asr.tokenizer.build_prompt_tokens(
            n_audio_tokens=n_audio_tokens,
            language=language,
        )
        # Append the tail of the committed tokens as context. A non-positive
        # max_context_tokens disables context ([-0:] would select everything).
        max_ctx = cfg.max_context_tokens
        if max_ctx > 0 and state.committed_token_ids:
            prompt_tokens.extend(state.committed_token_ids[-max_ctx:])
        input_ids = mx.array([prompt_tokens])
        seq_len = input_ids.shape[1]

        # 3. Find audio token range [audio_start, audio_end)
        audio_positions = [
            i for i, t in enumerate(prompt_tokens) if t == asr.audio_token_id
        ]
        if not audio_positions:
            return []
        audio_start = audio_positions[0]
        audio_end = audio_positions[-1] + 1

        # 4. MRoPE position IDs (same index for all three sections)
        positions = mx.arange(seq_len, dtype=mx.int32)[None, :]
        position_ids = mx.stack([positions, positions, positions], axis=1)

        # 5. Prefill; the cache leaves room for the decode budget (<= 120).
        cache = model.create_cache(max_seq_len=seq_len + 120)
        logits = model.prefill(input_ids, audio_features, position_ids, cache)
        mx.eval(logits)

        # 6. Install alignment hooks (decode steps only; prefill is done)
        capture = {"step_frames": []}
        originals = _install_alignment_hooks(
            model, asr.heads_by_layer, asr.gqa_ratio,
            audio_start, audio_end, capture,
        )

        # 7. Decode loop with border-distance policy
        eos_ids = set(asr.tokenizer.EOS_TOKEN_IDS)
        per_step_frames: List[List[int]] = []
        last_attend_frame = state.last_attend_frame
        border_stop_step: Optional[int] = None
        border_threshold = max(2, int(n_audio_tokens * cfg.border_fraction))
        rewind_threshold = max(2, int(n_audio_tokens * cfg.rewind_fraction))

        # Max tokens: ~6 tokens/sec of speech + margin
        new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
        if is_last:
            max_tokens = min(int(audio_duration * 6) + 10, 120)
        else:
            max_tokens = min(int(max(new_audio_secs, 1.0) * 6) + 5, 40)

        token = int(mx.argmax(logits.reshape(-1)).item())
        generated = [token]
        try:
            for step in range(1, max_tokens):
                if token in eos_ids:
                    break
                if _detect_repetition(generated):
                    break
                next_ids = mx.array([[token]])
                pos_val = seq_len + step - 1
                next_pos = mx.array([[[pos_val], [pos_val], [pos_val]]], dtype=mx.int32)
                logits = model.step(next_ids, next_pos, cache, validate_input_ids=False)
                mx.eval(logits)
                token = int(mx.argmax(logits.reshape(-1)).item())
                generated.append(token)

                # Collect frames captured during this step
                if capture["step_frames"]:
                    per_step_frames.append(capture["step_frames"])
                    capture["step_frames"] = []

                # Border-distance check (skip first 3 steps)
                if (not is_last
                        and border_stop_step is None
                        and len(per_step_frames) >= 3):
                    latest = per_step_frames[-1]
                    if latest:
                        frames_sorted = sorted(latest)
                        attended = frames_sorted[len(frames_sorted) // 2]
                        # Rewind check
                        if last_attend_frame - attended > rewind_threshold:
                            border_stop_step = max(0, len(per_step_frames) - 2)
                            break
                        last_attend_frame = attended
                        # Border check
                        if (n_audio_tokens - attended) <= border_threshold:
                            border_stop_step = len(per_step_frames) - 1
                            break

                # Periodic eval to prevent graph buildup
                if step % 8 == 0:
                    mx.eval(cache.keys[-1])
        finally:
            _remove_alignment_hooks(model, originals)

        # Flush remaining frames
        if capture["step_frames"]:
            per_step_frames.append(capture["step_frames"])
        state.last_attend_frame = last_attend_frame

        # 8. Process generated tokens: remove trailing EOS
        while generated and generated[-1] in eos_ids:
            generated.pop()
        num_gen = len(generated)
        if num_gen == 0:
            return []
        raw_text = asr.tokenizer.decode(generated)
        logger.info(
            "SimulStreaming raw: %d tokens (border_stop=%s), text=%r",
            num_gen, border_stop_step, raw_text[:100],
        )

        # 9. Strip metadata prefix ("language English<asr_text>...")
        detected_lang, _clean_text = parse_asr_output(
            raw_text,
            user_language=language,
        )
        # Number of leading tokens up to and including <asr_text> (searched
        # within the first 10 tokens).
        metadata_offset = 0
        asr_text_tokens = asr.tokenizer.encode("<asr_text>")
        asr_text_id = asr_text_tokens[0] if asr_text_tokens else None
        head = generated[:10]
        if asr_text_id is not None and asr_text_id in head:
            metadata_offset = head.index(asr_text_id) + 1
        if metadata_offset > 0:
            generated = generated[metadata_offset:]
            num_gen -= metadata_offset
            per_step_frames = per_step_frames[metadata_offset:]
            if border_stop_step is not None:
                # The border index was measured on the unstripped sequence;
                # shift it too, otherwise tokens decoded past the border
                # would be emitted.
                border_stop_step = max(0, border_stop_step - metadata_offset)
            if num_gen <= 0:
                return []

        # Detect language
        if state.detected_language is None and detected_lang and detected_lang != "unknown":
            state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
                detected_lang, detected_lang.lower(),
            )
            logger.info("Auto-detected language: %s", state.detected_language)

        # 10. Determine how many tokens to emit
        step_frames = [f for f in per_step_frames if f]
        if border_stop_step is None:
            emit_up_to = num_gen
        else:
            emit_up_to = min(border_stop_step, num_gen)
        if emit_up_to <= 0:
            return []
        emitted_ids = generated[:emit_up_to]

        # 11. Build timestamped words
        words = self._build_timestamped_words(
            emitted_ids, step_frames, emit_up_to,
            n_audio_tokens, audio_duration,
        )

        # Update state. Only the last max_context_tokens ids are ever fed back
        # as context, so older ids are dropped to keep the list bounded.
        state.committed_word_count += len(words)
        committed = state.committed_token_ids
        committed.extend(emitted_ids)
        keep = max(max_ctx, 0)
        if len(committed) > keep:
            del committed[:len(committed) - keep]
        return words

    def _build_timestamped_words(
        self,
        generated_ids: List[int],
        step_frames: List[List[int]],
        emit_up_to: int,
        n_audio_tokens: int,
        audio_duration: float,
    ) -> List[ASRToken]:
        """Build timestamped ASRToken list from generated tokens and
        alignment-head captured frames.

        Each token's frame is the median of its head votes. The decoded text is
        split on whitespace and each word takes the frame at its proportional
        position among the known frames (or, with no frames, an evenly spaced
        time). Frames map linearly onto the buffer duration, offset by the
        audio trimmed so far; each token spans 0.1 s.
        """
        state = self.state
        asr = self.asr

        # Per-token attended frame (median of head votes)
        per_token_frame: List[Optional[int]] = []
        for step_idx in range(emit_up_to):
            if step_idx < len(step_frames) and step_frames[step_idx]:
                frames = sorted(step_frames[step_idx])
                per_token_frame.append(frames[len(frames) // 2])
            else:
                per_token_frame.append(None)

        # Decode full text, split into words
        full_text = asr.tokenizer.decode(generated_ids[:emit_up_to])
        text_words = full_text.split()

        # Map words to frames proportionally
        all_frames = [f for f in per_token_frame if f is not None]
        word_frame_pairs = []
        for wi, word in enumerate(text_words):
            if all_frames:
                frac = wi / max(len(text_words), 1)
                frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
                frame = all_frames[frame_idx]
            else:
                frame = None
            word_frame_pairs.append((word, frame))

        # Convert to ASRToken
        tokens = []
        for i, (text, frame) in enumerate(word_frame_pairs):
            text = text.strip()
            if not text:
                continue
            if frame is not None and n_audio_tokens > 0:
                timestamp = (
                    frame / n_audio_tokens * audio_duration
                    + state.cumulative_time_offset
                )
            else:
                timestamp = (
                    (i / max(len(word_frame_pairs), 1)) * audio_duration
                    + state.cumulative_time_offset
                )
            # Only the very first word of the session has no leading space.
            is_very_first_word = (i == 0 and state.committed_word_count == 0)
            display_text = text if is_very_first_word else " " + text
            token = ASRToken(
                start=round(timestamp, 2),
                end=round(timestamp + 0.1, 2),
                text=display_text,
                speaker=state.speaker,
                detected_language=state.detected_language,
            ).with_offset(state.global_time_offset)
            tokens.append(token)
        return tokens

    # -- silence / speaker / lifecycle --

    def _flush_final(self) -> List[ASRToken]:
        """Run final-mode passes until one yields nothing (bounded)."""
        all_tokens: List[ASRToken] = []
        for _ in range(self._MAX_FLUSH_PASSES):
            tokens, _ = self.process_iter(is_last=True)
            if not tokens:
                break
            all_tokens.extend(tokens)
        return all_tokens

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush everything decodable from the buffer when a silence starts."""
        return self._flush_final(), self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Account for a silence of *silence_duration* seconds.

        Short silences are appended to the buffer as zeros so decoding keeps
        its context. Silences of at least MIN_DURATION_REAL_SILENCE reset the
        session state (buffer, context, offsets), keeping the speaker label.
        """
        self.end += silence_duration
        long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
        if not long_silence:
            gap_len = int(self.SAMPLING_RATE * silence_duration)
            if gap_len > 0:
                gap_silence = np.zeros(gap_len, dtype=np.float32)
                self.state.audio_buffer = np.append(
                    self.state.audio_buffer, gap_silence,
                )
        else:
            speaker = self.state.speaker
            self.state = _SessionState()
            # A pause does not imply a speaker change; without this, tokens
            # after the pause would carry the default label until the next
            # new_speaker event.
            self.state.speaker = speaker
            self.state.global_time_offset = silence_duration + offset

    def new_speaker(self, change_speaker):
        """Flush the buffer, then start a fresh state for the new speaker."""
        # NOTE(review): tokens produced by this flush are neither returned nor
        # stored — confirm the caller does not expect them.
        self.process_iter(is_last=True)
        self.state = _SessionState()
        self.state.speaker = change_speaker.speaker
        self.state.global_time_offset = change_speaker.start

    def get_buffer(self) -> Transcript:
        return Transcript.from_tokens(tokens=self.buffer, sep='')

    def warmup(self, audio: np.ndarray, init_prompt: str = ""):
        """Run one final-mode pass on at most the first second of *audio*,
        then reset the session state. Failures are logged, not raised."""
        try:
            self.state.audio_buffer = audio[:SAMPLE_RATE]
            self.process_iter(is_last=True)
            self.state = _SessionState()
            logger.info("Qwen3 MLX SimulStreaming processor warmed up")
        except Exception as e:
            logger.warning("Warmup failed: %s", e)
            self.state = _SessionState()

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Flush everything decodable from the buffer at end of stream."""
        return self._flush_final(), self.end