5 Commits

Author SHA1 Message Date
Quentin Fuxa
b102e12943 M5 benchmark figures: WER vs RTF scatter, 0.6B+1.7B MLX results 2026-03-15 15:00:00 +01:00
Quentin Fuxa
7aa3b764bd MLX benchmark: 1.7B SimulStreaming on M5 (WER 4.07%, RTF 0.944)
LibriSpeech test-clean, 500 utterances.
1.7B is borderline real-time on M5 (RTF 0.944).
0.6B (3.30% WER, 0.263 RTF) is the practical choice for MacBook.
2026-03-15 14:00:00 +01:00
Quentin Fuxa
a422e604ae MLX benchmark: 0.6B SimulStreaming on M5 MacBook (WER 3.30%, RTF 0.263)
LibriSpeech test-clean, 500 utterances, per-utterance simul-streaming.
AlignAtt border detection with 20 alignment heads.
Platform: Apple M5 32GB (MLX fp16).

benchmark_mlx_simul.py: reusable benchmark script for MLX backends.
2026-03-15 13:00:00 +01:00
Quentin Fuxa
e14b913807 Merge branch 'benchmarks-h100' 2026-03-15 12:00:00 +01:00
Quentin Fuxa
3b7a2fcc87 Add Qwen3-ASR MLX SimulStreaming backend
New backend 'qwen3-mlx-simul' for Apple Silicon: AlignAtt border
detection via monkey-patched cross-attention on MLX Qwen3-ASR.
Supports 0.6B (RTF 0.236 on M5) and 1.7B models.

- qwen3_mlx_simul.py: full streaming implementation with KV cache,
  alignment head attention extraction, border-distance policy
- core.py: register new backend in TranscriptionEngine + online_factory
- parse_args.py: add qwen3-mlx-simul to CLI choices
2026-03-15 11:00:00 +01:00
9 changed files with 12417 additions and 2 deletions

460
benchmark_mlx_simul.py Normal file
View File

@@ -0,0 +1,460 @@
#!/usr/bin/env python3
"""
Benchmark Qwen3-ASR MLX SimulStreaming on LibriSpeech test-clean.
Measures:
- Word Error Rate (WER) via jiwer
- Real-Time Factor (RTF) = total_inference_time / total_audio_duration
- Per-utterance stats
Usage:
# Per-utterance simul-streaming (default)
python benchmark_mlx_simul.py --model-size 0.6b
# Single-shot (batch-like, no streaming chunking)
python benchmark_mlx_simul.py --model-size 0.6b --single-shot
# Quick test with 100 utterances
python benchmark_mlx_simul.py --model-size 0.6b --max-utterances 100
# Chapter-grouped (matching H100 benchmark methodology)
python benchmark_mlx_simul.py --model-size 0.6b --chapter-grouped
"""
import argparse
import json
import logging
import os
import re
import sys
import time
from collections import defaultdict
from pathlib import Path
import numpy as np
import soundfile as sf
from jiwer import wer as compute_wer, cer as compute_cer
# Add WhisperLiveKit to path
WLKIT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(WLKIT_DIR))
from whisperlivekit.qwen3_mlx_simul import (
Qwen3MLXSimulStreamingASR,
Qwen3MLXSimulStreamingOnlineProcessor,
)
# Root logging at WARNING keeps third-party libraries quiet; the benchmark's
# own logger is raised to INFO below so progress lines still show.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("benchmark")
logger.setLevel(logging.INFO)

# Sample rate expected by the model; LibriSpeech FLACs are already 16 kHz,
# so resampling below is only a fallback.
SAMPLE_RATE = 16_000

# Alignment heads paths
# Per-model JSON files listing decoder (layer, head) pairs used for AlignAtt
# border detection. Missing file -> the ASR falls back to a heuristic.
ALIGNMENT_HEADS = {
    "0.6b": str(WLKIT_DIR / "scripts" / "alignment_heads_qwen3_asr_0.6B.json"),
    "1.7b": str(WLKIT_DIR / "scripts" / "alignment_heads_qwen3_asr_1.7B_v2.json"),
}
def load_librispeech_utterances(data_dir: str, max_utterances: int = 0):
    """Yield LibriSpeech utterances as (utt_id, audio_np, reference_text, duration_s).

    Walks every ``*.trans.txt`` under *data_dir* in sorted order; stops after
    *max_utterances* items when that limit is positive.
    """
    root = Path(data_dir)
    emitted = 0
    for trans_file in sorted(root.rglob("*.trans.txt")):
        chapter_dir = trans_file.parent
        with open(trans_file) as handle:
            for raw in handle:
                raw = raw.strip()
                if not raw:
                    continue
                # Transcript lines are "<utt_id> <reference text>".
                utt_id, _, ref_text = raw.partition(" ")
                flac_path = chapter_dir / f"{utt_id}.flac"
                if not flac_path.exists():
                    logger.warning("Missing FLAC: %s", flac_path)
                    continue
                audio, sr = sf.read(str(flac_path), dtype="float32")
                if sr != SAMPLE_RATE:
                    # Rare fallback: LibriSpeech is normally 16 kHz already.
                    import librosa
                    audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
                yield utt_id, audio, ref_text, len(audio) / SAMPLE_RATE
                emitted += 1
                if max_utterances > 0 and emitted >= max_utterances:
                    return
def load_librispeech_chapters(data_dir: str):
    """Load LibriSpeech grouped by speaker-chapter.

    Concatenates all utterances within each speaker/chapter into one long
    audio array, separated by 0.5 s of silence.
    Returns a list of (chapter_id, audio_np, reference_text, duration_s).
    """
    chapters = []
    for trans_file in sorted(Path(data_dir).rglob("*.trans.txt")):
        chapter_dir = trans_file.parent
        # Chapter id combines speaker dir and chapter dir names.
        full_id = f"{chapter_dir.parent.name}-{chapter_dir.name}"
        clips, refs = [], []
        with open(trans_file) as handle:
            for raw in handle:
                raw = raw.strip()
                if not raw:
                    continue
                utt_id, _, ref_text = raw.partition(" ")
                flac_path = chapter_dir / f"{utt_id}.flac"
                if not flac_path.exists():
                    continue
                audio, sr = sf.read(str(flac_path), dtype="float32")
                if sr != SAMPLE_RATE:
                    import librosa
                    audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
                clips.append(audio)
                refs.append(ref_text)
        if not clips:
            continue
        # Concatenate with 0.5s silence between utterances
        gap = np.zeros(int(0.5 * SAMPLE_RATE), dtype=np.float32)
        pieces = []
        for idx, clip in enumerate(clips):
            if idx:
                pieces.append(gap)
            pieces.append(clip)
        combined_audio = np.concatenate(pieces)
        chapters.append(
            (full_id, combined_audio, " ".join(refs), len(combined_audio) / SAMPLE_RATE)
        )
    return chapters
def transcribe_simul(asr, audio, chunk_seconds=2.0):
    """Transcribe using SimulStreaming with chunked audio feed.

    Feeds *audio* in chunks of *chunk_seconds*, collecting tokens after each
    step plus a final flush. Returns (transcription_text, inference_time_seconds).
    """
    session = Qwen3MLXSimulStreamingOnlineProcessor(asr)
    step = int(chunk_seconds * SAMPLE_RATE)
    n_samples = len(audio)
    collected = []
    started = time.perf_counter()
    for begin in range(0, n_samples, step):
        end = min(begin + step, n_samples)
        # Stream time is the absolute position of the chunk's end.
        session.insert_audio_chunk(audio[begin:end], end / SAMPLE_RATE)
        fresh, _ = session.process_iter(is_last=end >= n_samples)
        if fresh:
            collected.extend(fresh)
    # Final flush
    trailing, _ = session.finish()
    if trailing:
        collected.extend(trailing)
    elapsed = time.perf_counter() - started
    return "".join(tok.text for tok in collected).strip(), elapsed
def transcribe_single_shot(asr, audio):
    """Transcribe by feeding all audio at once (batch-like).

    Returns (transcription_text, inference_time_seconds).
    """
    session = Qwen3MLXSimulStreamingOnlineProcessor(asr)
    started = time.perf_counter()
    session.insert_audio_chunk(audio, len(audio) / SAMPLE_RATE)
    tokens, _ = session.process_iter(is_last=True)
    # Flush any tokens held back by the streaming policy.
    trailing, _ = session.finish()
    if trailing:
        tokens.extend(trailing)
    elapsed = time.perf_counter() - started
    return "".join(tok.text for tok in tokens).strip(), elapsed
def normalize_text(text: str) -> str:
    """Normalize text for WER computation: uppercase, strip punctuation."""
    upper = text.upper()
    # Drop everything that is neither word character nor whitespace,
    # then collapse whitespace runs to single spaces.
    no_punct = re.sub(r"[^\w\s]", "", upper)
    return re.sub(r"\s+", " ", no_punct).strip()
def main():
    """CLI entry point: parse args, load the model, benchmark LibriSpeech,
    then print a summary and optionally save per-sample results to JSON."""
    parser = argparse.ArgumentParser(description="Benchmark Qwen3-ASR MLX SimulStreaming")
    parser.add_argument("--model-size", default="0.6b", choices=["0.6b", "1.7b"],
                        help="Model size (default: 0.6b)")
    parser.add_argument("--max-utterances", type=int, default=0,
                        help="Max utterances to process (0=all). Ignored in chapter mode.")
    parser.add_argument("--librispeech-dir", default="/tmp/LibriSpeech/test-clean",
                        help="Path to LibriSpeech test-clean directory")
    parser.add_argument("--single-shot", action="store_true",
                        help="Feed entire audio at once instead of streaming chunks")
    parser.add_argument("--chunk-seconds", type=float, default=2.0,
                        help="Chunk size in seconds for simul-streaming (default: 2.0)")
    parser.add_argument("--border-fraction", type=float, default=0.25,
                        help="Border fraction for AlignAtt stopping (default: 0.25, matching H100 config)")
    parser.add_argument("--chapter-grouped", action="store_true",
                        help="Group utterances by speaker-chapter (matching H100 methodology)")
    parser.add_argument("--output-json", default=None,
                        help="Save per-utterance results to JSON file")
    args = parser.parse_args()
    # Check alignment heads: a missing file is tolerated (heuristic fallback).
    heads_path = ALIGNMENT_HEADS.get(args.model_size)
    if heads_path and os.path.exists(heads_path):
        logger.info("Using alignment heads: %s", heads_path)
        with open(heads_path) as f:
            heads_data = json.load(f)
        n_heads = len(heads_data.get("alignment_heads_compact", []))
        logger.info(" Loaded %d alignment heads for border detection", n_heads)
    else:
        heads_path = None
        logger.warning("No alignment heads file found for %s! Using default heuristic.",
                       args.model_size)
    # Load model (timed so load cost is reported separately from inference).
    logger.info("Loading Qwen3-ASR-%s MLX SimulStreaming model...", args.model_size.upper())
    t_load_start = time.perf_counter()
    asr = Qwen3MLXSimulStreamingASR(
        model_size=args.model_size,
        lan="en",
        alignment_heads_path=heads_path,
        border_fraction=args.border_fraction,
    )
    t_load_end = time.perf_counter()
    logger.info("Model loaded in %.2fs", t_load_end - t_load_start)
    # Verify alignment heads actually installed on the loaded model.
    logger.info("Alignment heads active: %d heads across %d layers",
                len(asr.alignment_heads), len(asr.heads_by_layer))
    if asr.alignment_heads:
        layers = sorted(asr.heads_by_layer.keys())
        logger.info(" Active layers: %s", layers[:10])
        logger.info(" First 5 heads: %s", asr.alignment_heads[:5])
    logger.info("Config: border_fraction=%.2f, chunk_seconds=%.1f",
                args.border_fraction, args.chunk_seconds)
    # Warmup: one throwaway inference so compilation/caching is excluded
    # from the measured RTF.
    logger.info("Running warmup inference...")
    dummy_audio = np.random.randn(SAMPLE_RATE * 3).astype(np.float32) * 0.01
    if args.single_shot:
        _, warmup_time = transcribe_single_shot(asr, dummy_audio)
    else:
        _, warmup_time = transcribe_simul(asr, dummy_audio, args.chunk_seconds)
    logger.info("Warmup done in %.2fs", warmup_time)
    # Determine mode string used in logs and the final report.
    mode = "single-shot" if args.single_shot else "simul-streaming"
    if args.chapter_grouped:
        mode += " (chapter-grouped)"
    logger.info("Starting benchmark: model=%s, mode=%s, bf=%.2f, chunk=%.1fs",
                args.model_size, mode, args.border_fraction, args.chunk_seconds)
    logger.info("LibriSpeech dir: %s", args.librispeech_dir)
    # Load data: either chapter-concatenated samples or single utterances.
    if args.chapter_grouped:
        samples = load_librispeech_chapters(args.librispeech_dir)
        logger.info("Loaded %d speaker-chapters", len(samples))
    else:
        samples = list(load_librispeech_utterances(
            args.librispeech_dir, args.max_utterances
        ))
        logger.info("Loaded %d utterances", len(samples))
    # Run benchmark, accumulating corpus-level counters as we go.
    references = []
    hypotheses = []
    per_sample_results = []
    total_audio_duration = 0.0
    total_inference_time = 0.0
    for i, (sample_id, audio, ref_text, duration) in enumerate(samples):
        if args.single_shot:
            hyp_text, infer_time = transcribe_single_shot(asr, audio)
        else:
            hyp_text, infer_time = transcribe_simul(asr, audio, args.chunk_seconds)
        ref_norm = normalize_text(ref_text)
        hyp_norm = normalize_text(hyp_text)
        # Per-sample WER (empty reference would make jiwer raise; treat as 0).
        if ref_norm:
            sample_wer = compute_wer(ref_norm, hyp_norm)
        else:
            sample_wer = 0.0
        total_audio_duration += duration
        total_inference_time += infer_time
        references.append(ref_norm)
        hypotheses.append(hyp_norm)
        result = {
            "id": sample_id,
            "ref": ref_text,
            "hyp": hyp_text,
            "ref_norm": ref_norm,
            "hyp_norm": hyp_norm,
            "duration_s": round(duration, 3),
            "infer_time_s": round(infer_time, 3),
            "rtf": round(infer_time / duration, 4) if duration > 0 else 0,
            "wer": round(sample_wer, 4),
        }
        per_sample_results.append(result)
        # Progress logging: first 5 samples, then every 50th.
        if (i + 1) % 50 == 0 or (i + 1) <= 5:
            running_wer = compute_wer(references, hypotheses)
            running_rtf = total_inference_time / total_audio_duration if total_audio_duration > 0 else 0
            logger.info(
                "[%d/%d] id=%s dur=%.1fs infer=%.2fs rtf=%.3f wer=%.1f%% "
                "| running: wer=%.2f%% rtf=%.3f",
                i + 1, len(samples), sample_id, duration, infer_time,
                infer_time / duration if duration > 0 else 0,
                sample_wer * 100, running_wer * 100, running_rtf,
            )
        # Show first few transcriptions for a quick sanity check.
        if i < 3:
            logger.info(" REF: %s", ref_text[:120])
            logger.info(" HYP: %s", hyp_text[:120])
    # Final results
    n_samples = len(references)
    if n_samples == 0:
        logger.error("No samples processed!")
        return
    total_wer = compute_wer(references, hypotheses)
    total_cer = compute_cer(references, hypotheses)
    total_rtf = total_inference_time / total_audio_duration if total_audio_duration > 0 else 0
    total_ref_words = sum(len(r.split()) for r in references)
    total_hyp_words = sum(len(h.split()) for h in hypotheses)
    # Distribution stats over per-sample WERs.
    wers = [r["wer"] for r in per_sample_results]
    wers_sorted = sorted(wers)
    median_wer = wers_sorted[len(wers_sorted) // 2]
    p90_wer = wers_sorted[int(len(wers_sorted) * 0.9)]
    p95_wer = wers_sorted[int(len(wers_sorted) * 0.95)]
    zero_wer_count = sum(1 for w in wers if w == 0.0)
    unit = "chapters" if args.chapter_grouped else "utterances"
    print("\n" + "=" * 70)
    print(f"BENCHMARK RESULTS: Qwen3-ASR-{args.model_size.upper()} MLX SimulStreaming")
    print(f"Mode: {mode}")
    print(f"Config: border_fraction={args.border_fraction}, chunk={args.chunk_seconds}s")
    print("=" * 70)
    print(f"Samples ({unit}): {n_samples}")
    print(f"Total audio: {total_audio_duration:.1f}s ({total_audio_duration/60:.1f}min)")
    print(f"Total inference: {total_inference_time:.1f}s ({total_inference_time/60:.1f}min)")
    print(f"Reference words: {total_ref_words}")
    print(f"Hypothesis words: {total_hyp_words}")
    print("-" * 70)
    print(f"WER: {total_wer * 100:.2f}%")
    print(f"CER: {total_cer * 100:.2f}%")
    print(f"RTF: {total_rtf:.4f}")
    if total_rtf > 0:
        print(f" (1/RTF = {1/total_rtf:.1f}x realtime)")
    print("-" * 70)
    print(f"Median {unit[:3]} WER: {median_wer * 100:.2f}%")
    print(f"P90 {unit[:3]} WER: {p90_wer * 100:.2f}%")
    print(f"P95 {unit[:3]} WER: {p95_wer * 100:.2f}%")
    print(f"Zero-WER {unit[:3]}: {zero_wer_count}/{n_samples} ({zero_wer_count/n_samples*100:.1f}%)")
    print("-" * 70)
    print(f"Alignment heads: {len(asr.alignment_heads)} heads, {len(asr.heads_by_layer)} layers")
    print(f"Heads file: {heads_path or 'NONE (default heuristic)'}")
    print(f"Model loaded in: {t_load_end - t_load_start:.2f}s")
    print("=" * 70)
    # H100 reference comparison (hard-coded figures from the H100 runs).
    print("\nH100 PyTorch SimulStream+KV reference (chapter-grouped, bf=0.25):")
    print(" 0.6B: WER 6.44%, RTF 0.109 (91 chapters, 602s)")
    print(" 1.7B: WER 8.09%, RTF 0.117 (91 chapters, 602s)")
    # Worst samples, with ref/hyp shown when WER is badly off (> 50%).
    worst = sorted(per_sample_results, key=lambda r: r["wer"], reverse=True)[:10]
    print(f"\nTop 10 worst {unit}:")
    for r in worst:
        print(f" {r['id']}: WER={r['wer']*100:.1f}% dur={r['duration_s']:.1f}s rtf={r['rtf']:.3f}")
        if r['wer'] > 0.5:
            print(f" REF: {r['ref_norm'][:80]}")
            print(f" HYP: {r['hyp_norm'][:80]}")
    # Save JSON results with config and per-sample detail for later plotting.
    if args.output_json:
        output = {
            "model": f"Qwen3-ASR-{args.model_size.upper()}",
            "backend": "mlx-simul-streaming",
            "mode": mode,
            "platform": "Apple M5 (32GB)",
            "config": {
                "border_fraction": args.border_fraction,
                "chunk_seconds": args.chunk_seconds,
                "chapter_grouped": args.chapter_grouped,
            },
            "n_samples": n_samples,
            "total_audio_s": round(total_audio_duration, 2),
            "total_inference_s": round(total_inference_time, 2),
            "wer": round(total_wer, 6),
            "cer": round(total_cer, 6),
            "rtf": round(total_rtf, 6),
            "median_wer": round(median_wer, 6),
            "p90_wer": round(p90_wer, 6),
            "p95_wer": round(p95_wer, 6),
            "alignment_heads_count": len(asr.alignment_heads),
            "alignment_heads_file": heads_path,
            "per_sample": per_sample_results,
        }
        with open(args.output_json, "w") as f:
            json.dump(output, f, indent=2)
        logger.info("Results saved to %s", args.output_json)


if __name__ == "__main__":
    main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,140 @@
#!/usr/bin/env python3
"""
Generate combined M5 vs H100 benchmark figure for WhisperLiveKit.
Produces a WER vs RTF scatter plot comparing Apple M5 (MLX) and
NVIDIA H100 results on LibriSpeech test-clean.
Note: M5 uses per-utterance evaluation (500 samples), while H100
uses chapter-grouped evaluation (91 chapters). Per-utterance WER
is typically lower because short utterances avoid long-range errors.
Run: python3 benchmarks/m5/generate_figures.py
"""
import json
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

DIR = os.path.dirname(os.path.abspath(__file__))


def _load_json(path):
    """Read a JSON file with a context manager so the handle is closed
    deterministically (the original ``json.load(open(...))`` leaked it)."""
    with open(path) as f:
        return json.load(f)


# Benchmark result tables produced by the H100 and M5 benchmark runs.
H100_DATA = _load_json(os.path.join(DIR, "..", "h100", "results.json"))
M5_DATA = _load_json(os.path.join(DIR, "results.json"))
# -- Style --
plt.rcParams.update({
    "font.family": "sans-serif",
    "font.size": 11,
    "axes.spines.top": False,
    "axes.spines.right": False,
})
# Fixed palette keyed by system family, shared by all figures.
COLORS = {
    "whisper": "#d63031",
    "qwen_b": "#6c5ce7",
    "qwen_s": "#00b894",
    "voxtral": "#fdcb6e",
    "m5_qwen": "#0984e3",
}
def _save(fig, name):
    """Write *fig* as a PNG next to this script and release its memory."""
    target = os.path.join(DIR, name)
    fig.savefig(target, dpi=180, bbox_inches="tight", facecolor="white")
    # Close so repeated figure generation does not accumulate memory.
    plt.close(fig)
    print(f" saved: {name}")
def fig_m5_vs_h100():
    """WER vs RTF scatter: M5 (MLX) and H100 (CUDA) on LibriSpeech test-clean.

    Reads H100_DATA / M5_DATA module globals and writes
    ``m5_vs_h100_wer_rtf.png`` via ``_save``.
    """
    h100 = H100_DATA["librispeech_clean"]["systems"]
    m5 = M5_DATA["models"]
    fig, ax = plt.subplots(figsize=(10, 7))
    # Light green band for "good WER" zone
    ax.axhspan(0, 5, color="#f0fff0", alpha=0.5, zorder=0)
    # --- H100 points ---
    # (label, data dict, color, marker, marker size) per system.
    h100_pts = [
        ("Whisper large-v3\n(H100, batch)", h100["whisper_large_v3_batch"], COLORS["whisper"], "h", 220),
        ("Qwen3 0.6B batch\n(H100)", h100["qwen3_0.6b_batch"], COLORS["qwen_b"], "h", 170),
        ("Qwen3 1.7B batch\n(H100)", h100["qwen3_1.7b_batch"], COLORS["qwen_b"], "h", 220),
        ("Voxtral 4B vLLM\n(H100)", h100["voxtral_4b_vllm_realtime"], COLORS["voxtral"], "D", 240),
        ("Qwen3 0.6B SimulStream+KV\n(H100)", h100["qwen3_0.6b_simulstream_kv"], COLORS["qwen_s"], "s", 200),
        ("Qwen3 1.7B SimulStream+KV\n(H100)", h100["qwen3_1.7b_simulstream_kv"], COLORS["qwen_s"], "s", 260),
    ]
    # Hand-tuned label offsets (points) to avoid overlapping annotations.
    h100_offsets = [(-55, 10), (-55, -22), (8, -18), (8, 10), (8, 10), (8, -18)]
    for (name, d, color, marker, sz), (lx, ly) in zip(h100_pts, h100_offsets):
        ax.scatter(d["rtf"], d["wer"], s=sz, c=color, marker=marker,
                   edgecolors="white", linewidths=1.5, zorder=5)
        ax.annotate(name, (d["rtf"], d["wer"]), fontsize=7.5, fontweight="bold",
                    xytext=(lx, ly), textcoords="offset points",
                    arrowprops=dict(arrowstyle="-", color="#aaa", lw=0.5))
    # --- M5 points ---
    m5_pts = [
        ("Qwen3 0.6B SimulStream\n(M5, MLX)", m5["qwen3-asr-0.6b-simul"], COLORS["m5_qwen"], "^", 260),
        ("Qwen3 1.7B SimulStream\n(M5, MLX)", m5["qwen3-asr-1.7b-simul"], COLORS["m5_qwen"], "^", 300),
    ]
    m5_offsets = [(8, 8), (8, -18)]
    for (name, d, color, marker, sz), (lx, ly) in zip(m5_pts, m5_offsets):
        ax.scatter(d["rtf"], d["wer"], s=sz, c=color, marker=marker,
                   edgecolors="white", linewidths=1.5, zorder=6)
        ax.annotate(name, (d["rtf"], d["wer"]), fontsize=7.5, fontweight="bold",
                    xytext=(lx, ly), textcoords="offset points",
                    arrowprops=dict(arrowstyle="-", color="#aaa", lw=0.5))
    # --- Connecting lines between same models on different hardware ---
    # 0.6B: H100 SimulStream+KV -> M5 SimulStream
    ax.plot([h100["qwen3_0.6b_simulstream_kv"]["rtf"], m5["qwen3-asr-0.6b-simul"]["rtf"]],
            [h100["qwen3_0.6b_simulstream_kv"]["wer"], m5["qwen3-asr-0.6b-simul"]["wer"]],
            "--", color="#0984e3", alpha=0.3, lw=1.5, zorder=3)
    # 1.7B: H100 SimulStream+KV -> M5 SimulStream
    ax.plot([h100["qwen3_1.7b_simulstream_kv"]["rtf"], m5["qwen3-asr-1.7b-simul"]["rtf"]],
            [h100["qwen3_1.7b_simulstream_kv"]["wer"], m5["qwen3-asr-1.7b-simul"]["wer"]],
            "--", color="#0984e3", alpha=0.3, lw=1.5, zorder=3)
    # --- RTF = 1 line (real-time boundary) ---
    ax.axvline(x=1.0, color="#e17055", linestyle=":", alpha=0.5, lw=1.5, zorder=1)
    ax.text(1.02, 0.5, "real-time\nboundary", fontsize=8, color="#e17055",
            fontstyle="italic", alpha=0.7, va="bottom")
    # --- Methodology note ---
    # M5 and H100 runs used different grouping, so WERs are not comparable.
    ax.text(0.98, 0.02,
            "H100: chapter-grouped WER (91 chapters) | M5: per-utterance WER (500 samples)\n"
            "Per-utterance WER is typically lower -- results are not directly comparable.",
            transform=ax.transAxes, fontsize=7.5, color="#666",
            ha="right", va="bottom", fontstyle="italic",
            bbox=dict(boxstyle="round,pad=0.3", fc="#fff9e6", ec="#ddd", alpha=0.9))
    ax.set_xlabel("RTF (lower = faster)")
    ax.set_ylabel("WER % (lower = better)")
    ax.set_title("H100 vs M5 (MLX) -- Qwen3-ASR on LibriSpeech test-clean",
                 fontsize=13, fontweight="bold", pad=12)
    ax.set_xlim(-0.01, 1.1)
    ax.set_ylim(-0.5, 10)
    ax.grid(True, alpha=0.12)
    legend = [
        mpatches.Patch(color=COLORS["whisper"], label="Whisper large-v3 (H100)"),
        mpatches.Patch(color=COLORS["qwen_b"], label="Qwen3-ASR batch (H100)"),
        mpatches.Patch(color=COLORS["qwen_s"], label="Qwen3 SimulStream+KV (H100)"),
        mpatches.Patch(color=COLORS["voxtral"], label="Voxtral 4B vLLM (H100)"),
        mpatches.Patch(color=COLORS["m5_qwen"], label="Qwen3 SimulStream (M5, MLX)"),
        plt.Line2D([0], [0], marker="h", color="w", mfc="gray", ms=8, label="Batch mode"),
        plt.Line2D([0], [0], marker="s", color="w", mfc="gray", ms=8, label="Streaming (H100)"),
        plt.Line2D([0], [0], marker="^", color="w", mfc="gray", ms=8, label="Streaming (M5)"),
    ]
    ax.legend(handles=legend, fontsize=8, loc="upper right", framealpha=0.85, ncol=2)
    _save(fig, "m5_vs_h100_wer_rtf.png")
# Script entry point: generate the single combined figure.
if __name__ == "__main__":
    print("Generating M5 vs H100 benchmark figure...")
    fig_m5_vs_h100()
    print("Done!")

Binary file not shown.

After

Width:  |  Height:  |  Size: 162 KiB

View File

@@ -0,0 +1,9 @@
{
"platform": "Apple M5 (32GB RAM, MLX fp16)",
"dataset": "LibriSpeech test-clean",
"methodology": "per-utterance (500 samples)",
"models": {
"qwen3-asr-0.6b-simul": {"wer": 3.30, "rtf": 0.263},
"qwen3-asr-1.7b-simul": {"wer": 4.07, "rtf": 0.944}
}
}

View File

@@ -121,6 +121,15 @@ class TranscriptionEngine:
self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
logger.info("Using Voxtral HF Transformers streaming backend")
elif config.backend == "qwen3-mlx-simul":
from whisperlivekit.qwen3_mlx_simul import Qwen3MLXSimulStreamingASR
self.tokenizer = None
self.asr = Qwen3MLXSimulStreamingASR(
**transcription_common_params,
alignment_heads_path=config.custom_alignment_heads,
border_fraction=getattr(config, 'border_fraction', 0.15),
)
logger.info("Using Qwen3 MLX SimulStreaming backend")
elif config.backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXASR
self.tokenizer = None
@@ -247,6 +256,9 @@ def online_factory(args, asr, language=None):
if backend == "qwen3-simul-kv":
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVOnlineProcessor
return Qwen3SimulKVOnlineProcessor(asr)
if backend == "qwen3-mlx-simul":
from whisperlivekit.qwen3_mlx_simul import Qwen3MLXSimulStreamingOnlineProcessor
return Qwen3MLXSimulStreamingOnlineProcessor(asr)
if backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor
return Qwen3MLXOnlineProcessor(asr)

View File

@@ -147,8 +147,8 @@ def parse_args():
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-mlx' for Qwen3-ASR on Apple Silicon (MLX). Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-mlx-simul", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3-mlx-simul' for Qwen3-ASR SimulStreaming on Apple Silicon (MLX). Use 'qwen3-mlx' for Qwen3-ASR LocalAgreement on MLX. Use 'qwen3-simul' for Qwen3-ASR SimulStreaming (PyTorch). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
)
parser.add_argument(
"--no-vac",

View File

@@ -0,0 +1,746 @@
"""
Qwen3-ASR SimulStreaming (AlignAtt) on MLX for Apple Silicon.
Uses the ``mlx_qwen3_asr`` library for model loading, audio encoding, and
tokenization. Implements the AlignAtt border-distance policy by monkey-
patching ``TextAttention.__call__`` on alignment layers to capture Q (with
RoPE) during autoregressive decode steps, then computing ``Q @ K_audio^T``
from the KV cache to find the most-attended audio frame.
This is the MLX equivalent of ``qwen3_simul.py`` (PyTorch) which uses
``register_forward_hook`` for the same purpose.
"""
import json
import logging
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16_000
# Model size aliases (same as qwen3_mlx_asr.py)
# Maps user-facing --model-size strings (including Whisper-style size names)
# to Hugging Face repo ids; unknown sizes fall back to 0.6B at the call site.
QWEN3_MLX_MODEL_MAPPING = {
    "base": "Qwen/Qwen3-ASR-0.6B",
    "tiny": "Qwen/Qwen3-ASR-0.6B",
    "small": "Qwen/Qwen3-ASR-0.6B",
    "large": "Qwen/Qwen3-ASR-1.7B",
    "medium": "Qwen/Qwen3-ASR-1.7B",
    "large-v3": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
    "qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
    "1.7b": "Qwen/Qwen3-ASR-1.7B",
    "0.6b": "Qwen/Qwen3-ASR-0.6B",
}
# Whisper language codes -> Qwen3 canonical language names
WHISPER_TO_QWEN3_LANGUAGE = {
    "zh": "Chinese", "en": "English", "yue": "Cantonese",
    "ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish",
    "pt": "Portuguese", "id": "Indonesian", "it": "Italian",
    "ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese",
    "ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay",
    "nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
    "pl": "Polish", "cs": "Czech", "fa": "Persian",
    "el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian",
}
# Reverse lookup: Qwen3 canonical name -> Whisper code.
QWEN3_TO_WHISPER_LANGUAGE = {v: k for k, v in WHISPER_TO_QWEN3_LANGUAGE.items()}
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class Qwen3MLXSimulConfig:
    """Tunable parameters for the MLX SimulStreaming decoder."""
    # Whisper-style language code, or "auto" for detection.
    language: str = "auto"
    # Path to the alignment-heads JSON file (None -> heuristic fallback).
    alignment_heads_path: Optional[str] = None
    # Fraction of the audio tail treated as the "border" for AlignAtt stopping.
    border_fraction: float = 0.15
    # Fraction rewound when re-decoding — semantics not visible in this chunk;
    # TODO confirm against the online processor.
    rewind_fraction: float = 0.12
    # Minimum / maximum buffered audio (seconds) per inference step.
    audio_min_len: float = 0.5
    audio_max_len: float = 15.0
    # Cap on committed context tokens carried between steps — TODO confirm usage.
    max_context_tokens: int = 30
    # At most this many (layer, head) pairs are used for border detection.
    max_alignment_heads: int = 20
# ---------------------------------------------------------------------------
# Per-session state
# ---------------------------------------------------------------------------
@dataclass
class _SessionState:
    """Mutable per-session decoding state (one instance per audio stream)."""
    # Raw float32 PCM not yet fully decoded.
    audio_buffer: np.ndarray = field(
        default_factory=lambda: np.array([], dtype=np.float32)
    )
    # Seconds of audio already trimmed from the front of the buffer.
    cumulative_time_offset: float = 0.0
    # Absolute stream-time origin of this session.
    global_time_offset: float = 0.0
    # Speaker id attached to emitted tokens (-1 = unknown).
    speaker: int = -1
    # Most-attended audio frame from the previous step; negative sentinel
    # before the first decode.
    last_attend_frame: int = -15
    # Number of words already committed (emitted as final).
    committed_word_count: int = 0
    # Token ids of committed text, re-fed as decoder context.
    committed_token_ids: List[int] = field(default_factory=list)
    # Language detected on the first step (None until then).
    detected_language: Optional[str] = None
    # Buffer length (samples) at the last inference, to gate re-decoding.
    last_infer_samples: int = 0
# ---------------------------------------------------------------------------
# Shared model holder
# ---------------------------------------------------------------------------
class Qwen3MLXSimulStreamingASR:
    """Loads the Qwen3-ASR model via ``mlx_qwen3_asr`` once and keeps it
    alive for the lifetime of the server. Shared across sessions."""
    # Separator used when joining committed text pieces (Qwen3 emits text
    # with its own spacing, so no extra separator is needed).
    sep = ""
    SAMPLING_RATE = SAMPLE_RATE

    def __init__(
        self,
        model_size: Optional[str] = None,
        model_dir: Optional[str] = None,
        model_path: Optional[str] = None,
        lan: str = "auto",
        alignment_heads_path: Optional[str] = None,
        border_fraction: float = 0.15,
        warmup_file: Optional[str] = None,
        model_cache_dir: Optional[str] = None,
        lora_path: Optional[str] = None,
        min_chunk_size: float = 0.1,
        direct_english_translation: bool = False,
        **kwargs,
    ):
        """Load model + tokenizer and resolve alignment heads.

        ``model_dir``/``model_path`` take precedence over ``model_size``;
        ``model_cache_dir``, ``lora_path``, ``min_chunk_size`` and
        ``direct_english_translation`` are accepted for interface parity with
        other backends but are unused here.
        """
        import mlx.core as mx
        import mlx_qwen3_asr
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        self.warmup_file = warmup_file
        self.cfg = Qwen3MLXSimulConfig(
            language=lan,
            alignment_heads_path=alignment_heads_path,
            border_fraction=border_fraction,
        )
        # Resolve model path: explicit dir/path wins; otherwise map the size
        # alias, letting slash/dot-prefixed strings pass through as paths.
        resolved = model_dir or model_path
        if not resolved:
            size = (model_size or "base").lower()
            if "/" in size or size.startswith("."):
                resolved = size
            else:
                resolved = QWEN3_MLX_MODEL_MAPPING.get(size, "Qwen/Qwen3-ASR-0.6B")
        t0 = time.time()
        logger.info("Loading Qwen3-ASR MLX model '%s' for SimulStreaming ...", resolved)
        self.model, self._config = mlx_qwen3_asr.load_model(resolved, dtype=mx.float16)
        logger.info("Model loaded in %.2fs", time.time() - t0)
        # Tokenizer: prefer the path the loader actually resolved to.
        tok_path = getattr(self.model, "_resolved_model_path", None) or resolved
        self.tokenizer = mlx_qwen3_asr.tokenizer.Tokenizer(str(tok_path))
        # Architecture info needed for attention capture (GQA head mapping).
        text_cfg = self._config.text_config
        self.num_layers = text_cfg.num_hidden_layers
        self.num_heads = text_cfg.num_attention_heads
        self.num_kv_heads = text_cfg.num_key_value_heads
        self.head_dim = text_cfg.head_dim
        self.gqa_ratio = self.num_heads // self.num_kv_heads
        self.audio_token_id = self._config.audio_token_id
        logger.info(
            "Qwen3-ASR arch: %d layers x %d heads (%d kv), head_dim=%d, GQA=%d",
            self.num_layers, self.num_heads, self.num_kv_heads,
            self.head_dim, self.gqa_ratio,
        )
        # Alignment heads, plus a layer -> [head, ...] index for hook install.
        self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
        self.heads_by_layer = {}
        for layer_idx, head_idx in self.alignment_heads:
            self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)
        self.backend_choice = "qwen3-mlx-simul"
        # Warmup (optional): one pass to trigger lazy compilation/caching.
        if warmup_file:
            from whisperlivekit.warmup import load_file
            audio = load_file(warmup_file)
            if audio is not None:
                self._warmup(audio)

    def _load_alignment_heads(
        self, path: Optional[str],
    ) -> List[Tuple[int, int]]:
        """Return up to ``max_alignment_heads`` (layer, head) pairs.

        Reads the JSON heads file when present; otherwise falls back to a
        heuristic: every head in the last quarter of decoder layers.
        """
        max_heads = self.cfg.max_alignment_heads
        if path and Path(path).exists():
            with open(path) as f:
                data = json.load(f)
            # File is assumed sorted best-first, so truncation keeps the top heads.
            all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
            heads = all_heads[:max_heads]
            logger.info(
                "Loaded top %d alignment heads from %s (of %d total)",
                len(heads), path, len(all_heads),
            )
            return heads
        # Default heuristic: last quarter of layers, all heads
        default_heads = []
        start_layer = self.num_layers * 3 // 4
        for layer in range(start_layer, self.num_layers):
            for head in range(self.num_heads):
                default_heads.append((layer, head))
        logger.warning(
            "No alignment heads file. Using default heuristic: "
            "%d heads from layers %d-%d.",
            len(default_heads), start_layer, self.num_layers - 1,
        )
        return default_heads[:max_heads]

    def _warmup(self, audio: np.ndarray):
        """Run one prefill over <=2s of audio to warm MLX kernels.

        Failures are logged and swallowed — warmup is best-effort.
        """
        import mlx.core as mx
        try:
            from mlx_qwen3_asr.audio import compute_features
            audio = audio[:SAMPLE_RATE * 2]
            mel, feat_lens = compute_features(audio)
            mel = mel.astype(mx.float16)
            audio_features, _ = self.model.audio_tower(mel, feat_lens)
            n_audio = int(audio_features.shape[1])
            prompt = self.tokenizer.build_prompt_tokens(n_audio, language="English")
            input_ids = mx.array([prompt])
            # MRoPE wants three stacked position streams (temporal/height/width).
            positions = mx.arange(input_ids.shape[1])[None, :]
            position_ids = mx.stack([positions, positions, positions], axis=1)
            cache = self.model.create_cache()
            logits = self.model.prefill(input_ids, audio_features, position_ids, cache)
            # Force lazy MLX graph evaluation so compilation happens now.
            mx.eval(logits)
            logger.info("Qwen3 MLX SimulStreaming warmup complete")
        except Exception as e:
            logger.warning("Warmup failed: %s", e)

    def transcribe(self, audio):
        """Interface stub: transcription happens in the online processor."""
        pass  # all work in the online processor
# ---------------------------------------------------------------------------
# Attention capture via wrapper replacement
# ---------------------------------------------------------------------------
class _AttnCaptureWrapper:
    """Wraps a TextAttention module to capture alignment scores during decode.
    Replaces ``layer.self_attn`` with this wrapper. On decode steps (L=1),
    recomputes Q with RoPE, reads cached K from the audio region, computes
    ``Q @ K_audio^T`` for alignment heads, and stores the argmax frame in
    ``capture["step_frames"]``.
    Python dunder resolution (``__call__``) goes through the *class*, not the
    instance, so monkey-patching ``attn.__call__`` on an ``nn.Module`` does
    not work. This wrapper class defines its own ``__call__`` and delegates
    everything else to the wrapped module via ``__getattr__``.
    """

    def __init__(self, original, layer_idx, head_indices, gqa_ratio,
                 audio_start, audio_end, capture):
        # Store in __dict__ directly to avoid triggering __getattr__
        # (a plain `self._original = ...` would be fine, but this makes the
        # delegation mechanics explicit and symmetric with __call__).
        self.__dict__["_original"] = original
        self.__dict__["_layer_idx"] = layer_idx
        self.__dict__["_head_indices"] = head_indices
        self.__dict__["_gqa_ratio"] = gqa_ratio
        self.__dict__["_audio_start"] = audio_start
        self.__dict__["_audio_end"] = audio_end
        self.__dict__["_capture"] = capture

    def __call__(self, x, cos, sin, mask=None, cache=None, layer_idx=0):
        import mlx.core as mx
        from mlx_qwen3_asr.mrope import apply_rotary_pos_emb
        orig = self.__dict__["_original"]
        B, L, _ = x.shape
        # Only capture on single-token autoregressive decode steps; prefill
        # (L > 1) passes straight through to the wrapped attention.
        if L == 1 and cache is not None:
            li = self.__dict__["_layer_idx"]
            h_indices = self.__dict__["_head_indices"]
            gqa = self.__dict__["_gqa_ratio"]
            a_start = self.__dict__["_audio_start"]
            a_end = self.__dict__["_audio_end"]
            cap = self.__dict__["_capture"]
            # Recompute Q with RoPE (cheap: single token)
            q = orig.q_proj(x)
            q = q.reshape(B, L, orig.num_heads, orig.head_dim)
            q = orig.q_norm(q)
            q = q.transpose(0, 2, 1, 3)  # (B, H, 1, D)
            q_rope, _ = apply_rotary_pos_emb(q, q, cos, sin)
            # K from cache (already has RoPE baked in from cache.update)
            k_cached = cache.keys[li]
            # Guard: skip capture until the cache covers the full audio span.
            if k_cached is not None and a_end <= k_cached.shape[2]:
                for h_idx in h_indices:
                    # Map query head -> shared KV head under GQA.
                    kv_h = h_idx // gqa
                    q_h = q_rope[0, h_idx, 0]  # (head_dim,)
                    k_audio = k_cached[0, kv_h, a_start:a_end]  # (n_audio, D)
                    scores = k_audio @ q_h  # (n_audio,)
                    frame = int(mx.argmax(scores).item())
                    cap["step_frames"].append(frame)
        # Always run the real attention; the capture above is read-only.
        return orig(x, cos, sin, mask=mask, cache=cache, layer_idx=layer_idx)

    def __getattr__(self, name):
        # Delegate every other attribute access to the wrapped module.
        return getattr(self.__dict__["_original"], name)
def _install_alignment_hooks(model, heads_by_layer, gqa_ratio, audio_start, audio_end, capture):
    """Swap capture wrappers in for ``self_attn`` on each alignment layer.

    Returns the ``(layer_idx, original_attn)`` pairs needed to undo the
    swap later via ``_remove_alignment_hooks``.
    """
    layers = model.model.layers
    saved = []
    for layer_idx, head_indices in heads_by_layer.items():
        if layer_idx >= len(layers):
            # Alignment-head config may reference layers this model lacks.
            continue
        target = layers[layer_idx]
        saved.append((layer_idx, target.self_attn))
        target.self_attn = _AttnCaptureWrapper(
            target.self_attn, layer_idx, head_indices, gqa_ratio,
            audio_start, audio_end, capture,
        )
    return saved
def _remove_alignment_hooks(model, originals):
"""Restore original self_attn modules."""
for layer_idx, orig_attn in originals:
model.model.layers[layer_idx].self_attn = orig_attn
# ---------------------------------------------------------------------------
# Per-session online processor
# ---------------------------------------------------------------------------
class Qwen3MLXSimulStreamingOnlineProcessor:
    """Per-session processor implementing AlignAtt on MLX.

    Same interface as other online processors:
    insert_audio_chunk / process_iter / get_buffer / start_silence /
    end_silence / finish / warmup / new_speaker.
    """
    # Audio sample rate (Hz) shared with the rest of the pipeline.
    SAMPLING_RATE = SAMPLE_RATE
    # Silences at least this long (seconds) reset the whole session state
    # instead of being padded with zeros (see end_silence).
    MIN_DURATION_REAL_SILENCE = 5
def __init__(self, asr: Qwen3MLXSimulStreamingASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer: List[ASRToken] = []
self.state = _SessionState()
    # -- properties expected by AudioProcessor --
    @property
    def speaker(self):
        """Current speaker label, delegated to the session state."""
        return self.state.speaker

    @speaker.setter
    def speaker(self, value):
        self.state.speaker = value

    @property
    def global_time_offset(self):
        """Absolute time offset applied to emitted tokens (via with_offset)."""
        return self.state.global_time_offset

    @global_time_offset.setter
    def global_time_offset(self, value):
        self.state.global_time_offset = value
# -- audio ingestion --
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self.state.audio_buffer = np.append(self.state.audio_buffer, audio)
# Trim if too long
max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
if len(self.state.audio_buffer) > max_samples:
trim = len(self.state.audio_buffer) - max_samples
self.state.audio_buffer = self.state.audio_buffer[trim:]
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
# -- main processing --
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
if audio_duration < self.asr.cfg.audio_min_len:
return [], self.end
# Throttle: at least 1s of new audio
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
if not is_last and new_samples < int(1.0 * self.SAMPLING_RATE):
return [], self.end
self.state.last_infer_samples = len(self.state.audio_buffer)
try:
words = self._infer(is_last)
except Exception as e:
logger.exception("Qwen3 MLX SimulStreaming inference error: %s", e)
return [], self.end
if not words:
return [], self.end
self.buffer = []
return words, self.end
    def _infer(self, is_last: bool) -> List[ASRToken]:
        """Run one inference cycle with alignment-head-based stopping.

        Pipeline: encode the buffered audio, build a prompt that replays
        previously committed tokens as context, prefill, then greedily
        decode while the alignment hooks record which audio frame each
        step attends to.  On non-final cycles decoding stops early
        (AlignAtt policy) once attention gets too close to the end of the
        audio ("border") or jumps backwards ("rewind").

        :param is_last: final flush — decode without the border policy.
        :return: newly committed, timestamped ASRToken words (may be []).
        """
        import mlx.core as mx
        from mlx_qwen3_asr.audio import compute_features
        from mlx_qwen3_asr.generate import _detect_repetition
        asr = self.asr
        state = self.state
        model = asr.model
        # 1. Encode audio
        mel, feat_lens = compute_features(state.audio_buffer)
        mel = mel.astype(mx.float16)
        audio_features, _ = model.audio_tower(mel, feat_lens)
        n_audio_tokens = int(audio_features.shape[1])
        # Force the (lazy) encoder graph now, before hooks are installed.
        mx.eval(audio_features)
        if n_audio_tokens == 0:
            return []
        audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE
        # 2. Build prompt tokens
        lan = asr.cfg.language
        language = None
        if lan and lan != "auto":
            # Config uses Whisper-style codes; map to the Qwen3 name.
            language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
        prompt_tokens = asr.tokenizer.build_prompt_tokens(
            n_audio_tokens=n_audio_tokens,
            language=language,
        )
        # Append committed context tokens so the model continues the
        # transcript rather than restarting it.
        if state.committed_token_ids:
            ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
            prompt_tokens.extend(ctx)
        input_ids = mx.array([prompt_tokens])
        seq_len = input_ids.shape[1]
        # 3. Find audio token range (hooks only score keys in this span)
        audio_positions = [
            i for i, t in enumerate(prompt_tokens) if t == asr.audio_token_id
        ]
        if not audio_positions:
            return []
        audio_start = audio_positions[0]
        audio_end = audio_positions[-1] + 1
        # 4. MRoPE position IDs (three identical position rows stacked)
        positions = mx.arange(seq_len, dtype=mx.int32)[None, :]
        position_ids = mx.stack([positions, positions, positions], axis=1)
        # 5. Prefill (cache sized for the prompt plus decode headroom)
        cache = model.create_cache(max_seq_len=seq_len + 120)
        logits = model.prefill(input_ids, audio_features, position_ids, cache)
        mx.eval(logits)
        # 6. Install alignment hooks
        capture = {"step_frames": []}
        originals = _install_alignment_hooks(
            model, asr.heads_by_layer, asr.gqa_ratio,
            audio_start, audio_end, capture,
        )
        # 7. Decode loop with border-distance policy
        eos_ids = set(asr.tokenizer.EOS_TOKEN_IDS)
        per_step_frames: List[List[int]] = []
        last_attend_frame = state.last_attend_frame
        border_stop_step: Optional[int] = None
        border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
        rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
        # Max tokens: ~6 tokens/sec of speech + margin.
        # Seconds of audio that arrived since last_infer_samples was last
        # recorded by the caller (process_iter).
        new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
        if is_last:
            max_tokens = min(int(audio_duration * 6) + 10, 120)
        else:
            max_tokens = min(int(max(new_audio_secs, 1.0) * 6) + 5, 40)
        # Greedy first token straight from the prefill logits.
        token = int(mx.argmax(logits.reshape(-1)).item())
        generated = [token]
        try:
            for step in range(1, max_tokens):
                if token in eos_ids:
                    break
                if _detect_repetition(generated):
                    break
                next_ids = mx.array([[token]])
                pos_val = seq_len + step - 1
                next_pos = mx.array([[[pos_val], [pos_val], [pos_val]]], dtype=mx.int32)
                logits = model.step(next_ids, next_pos, cache, validate_input_ids=False)
                mx.eval(logits)
                token = int(mx.argmax(logits.reshape(-1)).item())
                generated.append(token)
                # Collect frames from this step (one vote per alignment head)
                if capture["step_frames"]:
                    per_step_frames.append(capture["step_frames"])
                    capture["step_frames"] = []
                # Border-distance check (skip first 3 steps)
                if (not is_last
                    and border_stop_step is None
                    and len(per_step_frames) >= 3):
                    latest = per_step_frames[-1]
                    if latest:
                        # Median of the head votes for the latest step.
                        frames_sorted = sorted(latest)
                        attended = frames_sorted[len(frames_sorted) // 2]
                        # Rewind check: attention jumped backwards, likely a
                        # hallucinated continuation — stop before this step.
                        if last_attend_frame - attended > rewind_threshold:
                            border_stop_step = max(0, len(per_step_frames) - 2)
                            break
                        last_attend_frame = attended
                        # Border check: attention too close to the audio end.
                        if (n_audio_tokens - attended) <= border_threshold:
                            border_stop_step = len(per_step_frames) - 1
                            break
                # Periodic eval to prevent graph buildup
                if step % 8 == 0:
                    mx.eval(cache.keys[-1])
        finally:
            _remove_alignment_hooks(model, originals)
        # Flush remaining frames
        if capture["step_frames"]:
            per_step_frames.append(capture["step_frames"])
        state.last_attend_frame = last_attend_frame
        # 8. Process generated tokens
        # Remove trailing EOS
        while generated and generated[-1] in eos_ids:
            generated.pop()
        num_gen = len(generated)
        if num_gen == 0:
            return []
        raw_text = asr.tokenizer.decode(generated)
        logger.info(
            "SimulStreaming raw: %d tokens (border_stop=%s), text=%r",
            num_gen, border_stop_step, raw_text[:100],
        )
        # 9. Strip metadata prefix ("language English<asr_text>...")
        from mlx_qwen3_asr.tokenizer import parse_asr_output
        detected_lang, clean_text = parse_asr_output(
            raw_text,
            user_language=language,
        )
        # Find how many tokens to skip for metadata (scan first 10 only)
        metadata_offset = 0
        asr_text_tokens = asr.tokenizer.encode("<asr_text>")
        asr_text_id = asr_text_tokens[0] if asr_text_tokens else None
        if asr_text_id is not None:
            for i in range(min(num_gen, 10)):
                if generated[i] == asr_text_id:
                    metadata_offset = i + 1
                    break
        if metadata_offset > 0:
            generated = generated[metadata_offset:]
            num_gen -= metadata_offset
            # NOTE(review): per_step_frames entries correspond to decode
            # steps, one list per step with captured votes; trimming by the
            # token offset assumes one entry per generated token — confirm.
            per_step_frames = per_step_frames[metadata_offset:]
        if num_gen <= 0:
            return []
        # Detect language (latched once per session)
        if state.detected_language is None and detected_lang and detected_lang != "unknown":
            state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
                detected_lang, detected_lang.lower(),
            )
            logger.info("Auto-detected language: %s", state.detected_language)
        # 10. Determine how many tokens to emit
        step_frames = [f for f in per_step_frames if f]
        if border_stop_step is not None:
            emit_up_to = min(border_stop_step, num_gen)
        else:
            emit_up_to = num_gen
        if emit_up_to <= 0:
            return []
        emitted_ids = generated[:emit_up_to]
        # 11. Build timestamped words
        words = self._build_timestamped_words(
            emitted_ids, step_frames, emit_up_to,
            n_audio_tokens, audio_duration,
        )
        # Update state
        state.committed_word_count += len(words)
        state.committed_token_ids.extend(emitted_ids)
        return words
def _build_timestamped_words(
self,
generated_ids: List[int],
step_frames: List[List[int]],
emit_up_to: int,
n_audio_tokens: int,
audio_duration: float,
) -> List[ASRToken]:
"""Build timestamped ASRToken list from generated tokens and
alignment-head captured frames."""
state = self.state
asr = self.asr
# Per-token attended frame (median of head votes)
per_token_frame: List[Optional[int]] = []
for step_idx in range(emit_up_to):
if step_idx < len(step_frames) and step_frames[step_idx]:
frames = sorted(step_frames[step_idx])
per_token_frame.append(frames[len(frames) // 2])
else:
per_token_frame.append(None)
# Decode full text, split into words
full_text = asr.tokenizer.decode(generated_ids[:emit_up_to])
text_words = full_text.split()
# Map words to frames proportionally
all_frames = [f for f in per_token_frame if f is not None]
word_frame_pairs = []
for wi, word in enumerate(text_words):
if all_frames:
frac = wi / max(len(text_words), 1)
frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
frame = all_frames[frame_idx]
else:
frame = None
word_frame_pairs.append((word, frame))
# Convert to ASRToken
tokens = []
for i, (text, frame) in enumerate(word_frame_pairs):
text = text.strip()
if not text:
continue
if frame is not None and n_audio_tokens > 0:
timestamp = (
frame / n_audio_tokens * audio_duration
+ state.cumulative_time_offset
)
else:
timestamp = (
(i / max(len(word_frame_pairs), 1)) * audio_duration
+ state.cumulative_time_offset
)
is_very_first_word = (i == 0 and state.committed_word_count == 0)
display_text = text if is_very_first_word else " " + text
token = ASRToken(
start=round(timestamp, 2),
end=round(timestamp + 0.1, 2),
text=display_text,
speaker=state.speaker,
detected_language=state.detected_language,
).with_offset(state.global_time_offset)
tokens.append(token)
return tokens
# -- silence / speaker / lifecycle --
def start_silence(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end
def end_silence(self, silence_duration: float, offset: float):
self.end += silence_duration
long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
if not long_silence:
gap_len = int(self.SAMPLING_RATE * silence_duration)
if gap_len > 0:
gap_silence = np.zeros(gap_len, dtype=np.float32)
self.state.audio_buffer = np.append(
self.state.audio_buffer, gap_silence,
)
else:
self.state = _SessionState()
self.state.global_time_offset = silence_duration + offset
def new_speaker(self, change_speaker):
self.process_iter(is_last=True)
self.state = _SessionState()
self.state.speaker = change_speaker.speaker
self.state.global_time_offset = change_speaker.start
    def get_buffer(self) -> Transcript:
        # self.buffer is only ever reset to [] in this class (init and after a
        # successful process_iter), so this yields an empty transcript; the
        # method exists for interface parity with other online processors.
        return Transcript.from_tokens(tokens=self.buffer, sep='')
def warmup(self, audio: np.ndarray, init_prompt: str = ""):
try:
self.state.audio_buffer = audio[:SAMPLE_RATE]
self.process_iter(is_last=True)
self.state = _SessionState()
logger.info("Qwen3 MLX SimulStreaming processor warmed up")
except Exception as e:
logger.warning("Warmup failed: %s", e)
self.state = _SessionState()
def finish(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end