9 Commits

Author SHA1 Message Date
Quentin Fuxa
7fc341683a Remove unused benchmark and cascade bridge scripts 2026-03-31 23:10:44 +02:00
Quentin Fuxa
3f233dc36c Fix all ruff lint errors (68 errors → 0)
- Remove unused imports and variables (F401, F841)
- Sort import blocks (I001)
- Split semicolon-separated statements (E702)
- Fix backslash in f-string for Python 3.11 compat (cli.py)
- Remove empty f-strings (F541)
- Add noqa for intentional E402 after sys.path manipulation
2026-03-31 23:02:50 +02:00
Quentin Fuxa
db526ded34 Fix diarization failing on clips longer than ~1min (#349)
The diarization processor called diarize() only once per queue drain.
When audio was fed faster than real-time, the backend buffered all audio
but only processed one chunk, then blocked waiting for more from the
queue. Remaining buffered audio was never diarized, producing empty
segments.

- Drain the diarization buffer fully after each audio insertion
- Drain remaining buffer on SENTINEL before exiting
- Use extend instead of replace for incremental diarization segments
2026-03-31 22:55:34 +02:00
Quentin Fuxa
3e5d8c5820 Fix Qwen3 streaming decode budget and head loading 2026-03-23 23:03:11 +01:00
Quentin Fuxa
b102e12943 M5 benchmark figures: WER vs RTF scatter, 0.6B+1.7B MLX results 2026-03-15 15:00:00 +01:00
Quentin Fuxa
7aa3b764bd MLX benchmark: 1.7B SimulStreaming on M5 (WER 4.07%, RTF 0.944)
LibriSpeech test-clean, 500 utterances.
1.7B is borderline real-time on M5 (RTF 0.944).
0.6B (3.30% WER, 0.263 RTF) is the practical choice for MacBook.
2026-03-15 14:00:00 +01:00
Quentin Fuxa
a422e604ae MLX benchmark: 0.6B SimulStreaming on M5 MacBook (WER 3.30%, RTF 0.263)
LibriSpeech test-clean, 500 utterances, per-utterance simul-streaming.
AlignAtt border detection with 20 alignment heads.
Platform: Apple M5 32GB (MLX fp16).

benchmark_mlx_simul.py: reusable benchmark script for MLX backends.
2026-03-15 13:00:00 +01:00
Quentin Fuxa
e14b913807 Merge branch 'benchmarks-h100' 2026-03-15 12:00:00 +01:00
Quentin Fuxa
3b7a2fcc87 Add Qwen3-ASR MLX SimulStreaming backend
New backend 'qwen3-mlx-simul' for Apple Silicon: AlignAtt border
detection via monkey-patched cross-attention on MLX Qwen3-ASR.
Supports 0.6B (RTF 0.236 on M5) and 1.7B models.

- qwen3_mlx_simul.py: full streaming implementation with KV cache,
  alignment head attention extraction, border-distance policy
- core.py: register new backend in TranscriptionEngine + online_factory
- parse_args.py: add qwen3-mlx-simul to CLI choices
2026-03-15 11:00:00 +01:00
21 changed files with 12551 additions and 411 deletions

460
benchmark_mlx_simul.py Normal file
View File

@@ -0,0 +1,460 @@
#!/usr/bin/env python3
"""
Benchmark Qwen3-ASR MLX SimulStreaming on LibriSpeech test-clean.
Measures:
- Word Error Rate (WER) via jiwer
- Real-Time Factor (RTF) = total_inference_time / total_audio_duration
- Per-utterance stats
Usage:
# Per-utterance simul-streaming (default)
python benchmark_mlx_simul.py --model-size 0.6b
# Single-shot (batch-like, no streaming chunking)
python benchmark_mlx_simul.py --model-size 0.6b --single-shot
# Quick test with 100 utterances
python benchmark_mlx_simul.py --model-size 0.6b --max-utterances 100
# Chapter-grouped (matching H100 benchmark methodology)
python benchmark_mlx_simul.py --model-size 0.6b --chapter-grouped
"""
import argparse
import json
import logging
import os
import re
import sys
import time
from pathlib import Path
import numpy as np
import soundfile as sf
from jiwer import cer as compute_cer
from jiwer import wer as compute_wer
# Add WhisperLiveKit to path
WLKIT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(WLKIT_DIR))
from whisperlivekit.qwen3_mlx_simul import ( # noqa: E402
Qwen3MLXSimulStreamingASR,
Qwen3MLXSimulStreamingOnlineProcessor,
)
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("benchmark")
logger.setLevel(logging.INFO)
SAMPLE_RATE = 16_000
# Alignment heads paths
ALIGNMENT_HEADS = {
"0.6b": str(WLKIT_DIR / "scripts" / "alignment_heads_qwen3_asr_0.6B.json"),
"1.7b": str(WLKIT_DIR / "scripts" / "alignment_heads_qwen3_asr_1.7B_v2.json"),
}
def load_librispeech_utterances(data_dir: str, max_utterances: int = 0):
    """Yield LibriSpeech utterances as (utt_id, audio_np, reference_text, duration_s).

    Walks every ``*.trans.txt`` under *data_dir* (sorted for determinism),
    loads the FLAC file sitting next to it, and resamples to SAMPLE_RATE
    when the file rate differs.  ``max_utterances == 0`` means no limit.
    """
    emitted = 0
    for transcript_file in sorted(Path(data_dir).rglob("*.trans.txt")):
        flac_dir = transcript_file.parent
        with open(transcript_file) as handle:
            for raw_line in handle:
                entry = raw_line.strip()
                if not entry:
                    continue
                # Each line is "<utt_id> <reference text>"; text may be absent.
                utt_id, _, reference = entry.partition(" ")
                flac_file = flac_dir / f"{utt_id}.flac"
                if not flac_file.exists():
                    logger.warning("Missing FLAC: %s", flac_file)
                    continue
                samples, sample_rate = sf.read(str(flac_file), dtype="float32")
                if sample_rate != SAMPLE_RATE:
                    import librosa
                    samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=SAMPLE_RATE)
                yield utt_id, samples, reference, len(samples) / SAMPLE_RATE
                emitted += 1
                if max_utterances > 0 and emitted >= max_utterances:
                    return
def load_librispeech_chapters(data_dir: str):
    """Load LibriSpeech grouped by speaker-chapter.

    Concatenates every utterance of each speaker/chapter (with 0.5 s of
    silence between consecutive clips) into one long audio array.
    Returns a list of (chapter_id, audio_np, reference_text, duration_s).
    """
    results = []
    for transcript_file in sorted(Path(data_dir).rglob("*.trans.txt")):
        flac_dir = transcript_file.parent
        # Chapter id is "<speaker>-<chapter>" taken from the directory layout.
        label = f"{flac_dir.parent.name}-{flac_dir.name}"
        clip_list = []
        text_list = []
        with open(transcript_file) as handle:
            for raw_line in handle:
                entry = raw_line.strip()
                if not entry:
                    continue
                utt_id, _, reference = entry.partition(" ")
                flac_file = flac_dir / f"{utt_id}.flac"
                if not flac_file.exists():
                    continue
                samples, sample_rate = sf.read(str(flac_file), dtype="float32")
                if sample_rate != SAMPLE_RATE:
                    import librosa
                    samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=SAMPLE_RATE)
                clip_list.append(samples)
                text_list.append(reference)
        if not clip_list:
            continue
        # Join utterances with half a second of silence between each pair.
        gap = np.zeros(int(0.5 * SAMPLE_RATE), dtype=np.float32)
        pieces = []
        for idx, clip in enumerate(clip_list):
            if idx:
                pieces.append(gap)
            pieces.append(clip)
        merged = np.concatenate(pieces)
        results.append((label, merged, " ".join(text_list), len(merged) / SAMPLE_RATE))
    return results
def transcribe_simul(asr, audio, chunk_seconds=2.0):
    """Transcribe *audio* by feeding the streaming processor fixed-size chunks.

    Returns (transcription_text, inference_time_seconds); the timer covers
    chunk insertion, incremental decoding, and the final flush.
    """
    processor = Qwen3MLXSimulStreamingOnlineProcessor(asr)
    step = int(chunk_seconds * SAMPLE_RATE)
    n_samples = len(audio)
    collected = []
    started = time.perf_counter()
    for start in range(0, n_samples, step):
        stop = min(start + step, n_samples)
        # Stream time is the absolute position of the chunk end, in seconds.
        processor.insert_audio_chunk(audio[start:stop], stop / SAMPLE_RATE)
        emitted, _ = processor.process_iter(is_last=(stop >= n_samples))
        if emitted:
            collected.extend(emitted)
    # Final flush of anything still buffered inside the processor.
    tail, _ = processor.finish()
    if tail:
        collected.extend(tail)
    elapsed = time.perf_counter() - started
    return "".join(token.text for token in collected).strip(), elapsed
def transcribe_single_shot(asr, audio):
    """Transcribe by handing the processor the entire audio in one call (batch-like).

    Returns (transcription_text, inference_time_seconds).
    """
    processor = Qwen3MLXSimulStreamingOnlineProcessor(asr)
    started = time.perf_counter()
    processor.insert_audio_chunk(audio, len(audio) / SAMPLE_RATE)
    tokens, _ = processor.process_iter(is_last=True)
    # Flush whatever the processor still holds after the single pass.
    tail, _ = processor.finish()
    if tail:
        tokens.extend(tail)
    elapsed = time.perf_counter() - started
    return "".join(token.text for token in tokens).strip(), elapsed
def normalize_text(text: str) -> str:
    """Normalize text for WER computation: uppercase, strip punctuation."""
    # Drop everything that is neither a word character nor whitespace,
    # then collapse whitespace runs to single spaces.
    upper = re.sub(r"[^\w\s]", "", text.upper())
    return re.sub(r"\s+", " ", upper).strip()
def main():
    """CLI entry point: benchmark Qwen3-ASR MLX SimulStreaming on LibriSpeech.

    Parses arguments, loads the model (optionally with an alignment-heads
    file), runs one warmup pass, transcribes every sample, then prints
    WER/CER/RTF summaries and optionally dumps per-sample results to JSON.
    """
    parser = argparse.ArgumentParser(description="Benchmark Qwen3-ASR MLX SimulStreaming")
    parser.add_argument("--model-size", default="0.6b", choices=["0.6b", "1.7b"],
                        help="Model size (default: 0.6b)")
    parser.add_argument("--max-utterances", type=int, default=0,
                        help="Max utterances to process (0=all). Ignored in chapter mode.")
    parser.add_argument("--librispeech-dir", default="/tmp/LibriSpeech/test-clean",
                        help="Path to LibriSpeech test-clean directory")
    parser.add_argument("--single-shot", action="store_true",
                        help="Feed entire audio at once instead of streaming chunks")
    parser.add_argument("--chunk-seconds", type=float, default=2.0,
                        help="Chunk size in seconds for simul-streaming (default: 2.0)")
    parser.add_argument("--border-fraction", type=float, default=0.25,
                        help="Border fraction for AlignAtt stopping (default: 0.25, matching H100 config)")
    parser.add_argument("--chapter-grouped", action="store_true",
                        help="Group utterances by speaker-chapter (matching H100 methodology)")
    parser.add_argument("--output-json", default=None,
                        help="Save per-utterance results to JSON file")
    args = parser.parse_args()

    # Check alignment heads; fall back to the model default if the file is missing.
    heads_path = ALIGNMENT_HEADS.get(args.model_size)
    if heads_path and os.path.exists(heads_path):
        logger.info("Using alignment heads: %s", heads_path)
        with open(heads_path) as f:
            heads_data = json.load(f)
        n_heads = len(heads_data.get("alignment_heads_compact", []))
        logger.info(" Loaded %d alignment heads for border detection", n_heads)
    else:
        heads_path = None
        logger.warning("No alignment heads file found for %s! Using default heuristic.",
                       args.model_size)

    # Load model
    logger.info("Loading Qwen3-ASR-%s MLX SimulStreaming model...", args.model_size.upper())
    t_load_start = time.perf_counter()
    asr = Qwen3MLXSimulStreamingASR(
        model_size=args.model_size,
        lan="en",
        alignment_heads_path=heads_path,
        border_fraction=args.border_fraction,
    )
    t_load_end = time.perf_counter()
    logger.info("Model loaded in %.2fs", t_load_end - t_load_start)

    # Verify alignment heads actually made it into the loaded model.
    logger.info("Alignment heads active: %d heads across %d layers",
                len(asr.alignment_heads), len(asr.heads_by_layer))
    if asr.alignment_heads:
        layers = sorted(asr.heads_by_layer.keys())
        logger.info(" Active layers: %s", layers[:10])
        logger.info(" First 5 heads: %s", asr.alignment_heads[:5])
    logger.info("Config: border_fraction=%.2f, chunk_seconds=%.1f",
                args.border_fraction, args.chunk_seconds)

    # Warmup: one throwaway pass so first-call overhead doesn't skew timings.
    logger.info("Running warmup inference...")
    dummy_audio = np.random.randn(SAMPLE_RATE * 3).astype(np.float32) * 0.01
    if args.single_shot:
        _, warmup_time = transcribe_single_shot(asr, dummy_audio)
    else:
        _, warmup_time = transcribe_simul(asr, dummy_audio, args.chunk_seconds)
    logger.info("Warmup done in %.2fs", warmup_time)

    # Determine mode
    mode = "single-shot" if args.single_shot else "simul-streaming"
    if args.chapter_grouped:
        mode += " (chapter-grouped)"
    logger.info("Starting benchmark: model=%s, mode=%s, bf=%.2f, chunk=%.1fs",
                args.model_size, mode, args.border_fraction, args.chunk_seconds)
    logger.info("LibriSpeech dir: %s", args.librispeech_dir)

    # Load data
    if args.chapter_grouped:
        samples = load_librispeech_chapters(args.librispeech_dir)
        logger.info("Loaded %d speaker-chapters", len(samples))
    else:
        samples = list(load_librispeech_utterances(
            args.librispeech_dir, args.max_utterances
        ))
        logger.info("Loaded %d utterances", len(samples))

    # Run benchmark
    references = []
    hypotheses = []
    per_sample_results = []
    total_audio_duration = 0.0
    total_inference_time = 0.0
    for i, (sample_id, audio, ref_text, duration) in enumerate(samples):
        if args.single_shot:
            hyp_text, infer_time = transcribe_single_shot(asr, audio)
        else:
            hyp_text, infer_time = transcribe_simul(asr, audio, args.chunk_seconds)
        ref_norm = normalize_text(ref_text)
        hyp_norm = normalize_text(hyp_text)
        # Per-sample WER (guard: skip scoring when the reference is empty)
        if ref_norm:
            sample_wer = compute_wer(ref_norm, hyp_norm)
        else:
            sample_wer = 0.0
        total_audio_duration += duration
        total_inference_time += infer_time
        references.append(ref_norm)
        hypotheses.append(hyp_norm)
        result = {
            "id": sample_id,
            "ref": ref_text,
            "hyp": hyp_text,
            "ref_norm": ref_norm,
            "hyp_norm": hyp_norm,
            "duration_s": round(duration, 3),
            "infer_time_s": round(infer_time, 3),
            "rtf": round(infer_time / duration, 4) if duration > 0 else 0,
            "wer": round(sample_wer, 4),
        }
        per_sample_results.append(result)
        # Progress logging: first five samples, then every 50th.
        if (i + 1) % 50 == 0 or (i + 1) <= 5:
            running_wer = compute_wer(references, hypotheses)
            running_rtf = total_inference_time / total_audio_duration if total_audio_duration > 0 else 0
            logger.info(
                "[%d/%d] id=%s dur=%.1fs infer=%.2fs rtf=%.3f wer=%.1f%% "
                "| running: wer=%.2f%% rtf=%.3f",
                i + 1, len(samples), sample_id, duration, infer_time,
                infer_time / duration if duration > 0 else 0,
                sample_wer * 100, running_wer * 100, running_rtf,
            )
        # Show first few transcriptions
        if i < 3:
            logger.info(" REF: %s", ref_text[:120])
            logger.info(" HYP: %s", hyp_text[:120])

    # Final results
    n_samples = len(references)
    if n_samples == 0:
        logger.error("No samples processed!")
        return
    total_wer = compute_wer(references, hypotheses)
    total_cer = compute_cer(references, hypotheses)
    total_rtf = total_inference_time / total_audio_duration if total_audio_duration > 0 else 0
    total_ref_words = sum(len(r.split()) for r in references)
    total_hyp_words = sum(len(h.split()) for h in hypotheses)
    # Per-sample WER distribution (median and tail percentiles).
    wers = [r["wer"] for r in per_sample_results]
    wers_sorted = sorted(wers)
    median_wer = wers_sorted[len(wers_sorted) // 2]
    p90_wer = wers_sorted[int(len(wers_sorted) * 0.9)]
    p95_wer = wers_sorted[int(len(wers_sorted) * 0.95)]
    zero_wer_count = sum(1 for w in wers if w == 0.0)
    unit = "chapters" if args.chapter_grouped else "utterances"
    print("\n" + "=" * 70)
    print(f"BENCHMARK RESULTS: Qwen3-ASR-{args.model_size.upper()} MLX SimulStreaming")
    print(f"Mode: {mode}")
    print(f"Config: border_fraction={args.border_fraction}, chunk={args.chunk_seconds}s")
    print("=" * 70)
    print(f"Samples ({unit}): {n_samples}")
    print(f"Total audio: {total_audio_duration:.1f}s ({total_audio_duration/60:.1f}min)")
    print(f"Total inference: {total_inference_time:.1f}s ({total_inference_time/60:.1f}min)")
    print(f"Reference words: {total_ref_words}")
    print(f"Hypothesis words: {total_hyp_words}")
    print("-" * 70)
    print(f"WER: {total_wer * 100:.2f}%")
    print(f"CER: {total_cer * 100:.2f}%")
    print(f"RTF: {total_rtf:.4f}")
    if total_rtf > 0:
        print(f" (1/RTF = {1/total_rtf:.1f}x realtime)")
    print("-" * 70)
    print(f"Median {unit[:3]} WER: {median_wer * 100:.2f}%")
    print(f"P90 {unit[:3]} WER: {p90_wer * 100:.2f}%")
    print(f"P95 {unit[:3]} WER: {p95_wer * 100:.2f}%")
    print(f"Zero-WER {unit[:3]}: {zero_wer_count}/{n_samples} ({zero_wer_count/n_samples*100:.1f}%)")
    print("-" * 70)
    print(f"Alignment heads: {len(asr.alignment_heads)} heads, {len(asr.heads_by_layer)} layers")
    print(f"Heads file: {heads_path or 'NONE (default heuristic)'}")
    print(f"Model loaded in: {t_load_end - t_load_start:.2f}s")
    print("=" * 70)
    # H100 reference comparison
    print("\nH100 PyTorch SimulStream+KV reference (chapter-grouped, bf=0.25):")
    print(" 0.6B: WER 6.44%, RTF 0.109 (91 chapters, 602s)")
    print(" 1.7B: WER 8.09%, RTF 0.117 (91 chapters, 602s)")
    # Worst samples
    worst = sorted(per_sample_results, key=lambda r: r["wer"], reverse=True)[:10]
    print(f"\nTop 10 worst {unit}:")
    for r in worst:
        print(f" {r['id']}: WER={r['wer']*100:.1f}% dur={r['duration_s']:.1f}s rtf={r['rtf']:.3f}")
        if r['wer'] > 0.5:
            print(f" REF: {r['ref_norm'][:80]}")
            print(f" HYP: {r['hyp_norm'][:80]}")
    # Save JSON results
    if args.output_json:
        output = {
            "model": f"Qwen3-ASR-{args.model_size.upper()}",
            "backend": "mlx-simul-streaming",
            "mode": mode,
            "platform": "Apple M5 (32GB)",
            "config": {
                "border_fraction": args.border_fraction,
                "chunk_seconds": args.chunk_seconds,
                "chapter_grouped": args.chapter_grouped,
            },
            "n_samples": n_samples,
            "total_audio_s": round(total_audio_duration, 2),
            "total_inference_s": round(total_inference_time, 2),
            "wer": round(total_wer, 6),
            "cer": round(total_cer, 6),
            "rtf": round(total_rtf, 6),
            "median_wer": round(median_wer, 6),
            "p90_wer": round(p90_wer, 6),
            "p95_wer": round(p95_wer, 6),
            "alignment_heads_count": len(asr.alignment_heads),
            "alignment_heads_file": heads_path,
            "per_sample": per_sample_results,
        }
        with open(args.output_json, "w") as f:
            json.dump(output, f, indent=2)
        logger.info("Results saved to %s", args.output_json)


if __name__ == "__main__":
    main()

View File

@@ -1,124 +0,0 @@
#!/usr/bin/env python3
"""Standalone Voxtral benchmark — no whisperlivekit imports."""
import json, logging, re, time, wave, queue, threading
import numpy as np
import torch
# Run quiet by default; silence chatty third-party loggers entirely.
logging.basicConfig(level=logging.WARNING)
for n in ["transformers","torch","httpx"]:
    logging.getLogger(n).setLevel(logging.ERROR)
from jiwer import wer as compute_wer
from transformers import AutoProcessor, VoxtralRealtimeForConditionalGeneration, TextIteratorStreamer
def norm(t):
    """Lowercase *t*, map everything outside [a-z0-9 ] to spaces, collapse runs."""
    letters_only = re.sub(r'[^a-z0-9 ]', ' ', t.lower())
    return re.sub(r' +', ' ', letters_only).strip()
def load_audio(path):
    """Read a 16-bit PCM WAV file and return float32 samples scaled to [-1, 1)."""
    with wave.open(path, 'r') as wav_file:
        raw = wav_file.readframes(wav_file.getnframes())
    samples = np.frombuffer(raw, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0
# Load model
print("Loading Voxtral-Mini-4B...", flush=True)
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
processor = AutoProcessor.from_pretrained(MODEL_ID)
# bfloat16 weights, pinned to a single GPU via device_map.
model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda:0",
)
# Report resident GPU memory right after loading.
print(f"Loaded, GPU: {torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)
def transcribe_batch(audio_np):
    """Run one non-streaming Voxtral generation over *audio_np*.

    Returns (decoded_text, generation_seconds); the timer covers only the
    generate() call, not feature extraction or decoding.
    """
    # Voxtral expects audio as input_features from processor
    model_inputs = processor(
        audio=audio_np, sampling_rate=16000, return_tensors="pt",
    ).to("cuda:0").to(torch.bfloat16)
    started = time.perf_counter()
    with torch.inference_mode():
        output_ids = model.generate(**model_inputs, max_new_tokens=1024)
    seconds = time.perf_counter() - started
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    return decoded, seconds
# 1. LibriSpeech test-clean
print("\n=== Voxtral / LibriSpeech test-clean ===", flush=True)
clean = json.load(open("/home/cloud/benchmark_data/metadata.json"))
# ta accumulates audio seconds, tp accumulates processing (generate) seconds.
wers = []; ta = tp = 0
for i, s in enumerate(clean):
    audio = load_audio(s['path'])
    hyp, pt = transcribe_batch(audio)
    w = compute_wer(norm(s['reference']), norm(hyp))
    wers.append(w); ta += s['duration']; tp += pt
    # Log the first three samples, then every 20th.
    if i < 3 or i % 20 == 0:
        print(f" [{i}] {s['duration']:.1f}s RTF={pt/s['duration']:.2f} WER={w:.1%} | {hyp[:60]}", flush=True)
clean_wer = np.mean(wers); clean_rtf = tp/ta
print(f" CLEAN: WER {clean_wer:.2%}, RTF {clean_rtf:.3f} ({len(clean)} samples, {ta:.0f}s)")
# 2. LibriSpeech test-other (same loop, noisier split)
print("\n=== Voxtral / LibriSpeech test-other ===", flush=True)
other = json.load(open("/home/cloud/benchmark_data/metadata_other.json"))
wers2 = []; ta2 = tp2 = 0
for i, s in enumerate(other):
    audio = load_audio(s['path'])
    hyp, pt = transcribe_batch(audio)
    w = compute_wer(norm(s['reference']), norm(hyp))
    wers2.append(w); ta2 += s['duration']; tp2 += pt
    if i < 3 or i % 20 == 0:
        print(f" [{i}] {s['duration']:.1f}s RTF={pt/s['duration']:.2f} WER={w:.1%}", flush=True)
other_wer = np.mean(wers2); other_rtf = tp2/ta2
print(f" OTHER: WER {other_wer:.2%}, RTF {other_rtf:.3f} ({len(other)} samples, {ta2:.0f}s)")
# 3. ACL6060 conference talks (long-form audio)
print("\n=== Voxtral / ACL6060 ===", flush=True)
acl_results = []
for talk in ["110", "117", "268", "367", "590"]:
    audio = load_audio(f"/home/cloud/acl6060_audio/2022.acl-long.{talk}.wav")
    dur = len(audio) / 16000
    # Gold transcript: concatenation of the per-segment JSONL texts.
    gw = []
    with open(f"/home/cloud/iwslt26-sst/inputs/en/acl6060.ts/gold-jsonl/2022.acl-long.{talk}.jsonl") as f:
        for line in f:
            gw.append(json.loads(line)["text"].strip())
    gold = " ".join(gw)
    # For long audio, process in 30s chunks
    all_hyp = []
    t0 = time.perf_counter()
    chunk_size = 30 * 16000
    for start in range(0, len(audio), chunk_size):
        chunk = audio[start:start + chunk_size]
        if len(chunk) < 1600: # skip very short tail
            continue
        hyp, _ = transcribe_batch(chunk)
        all_hyp.append(hyp)
    t1 = time.perf_counter()
    full_hyp = " ".join(all_hyp)
    w = compute_wer(norm(gold), norm(full_hyp))
    rtf = (t1 - t0) / dur
    acl_results.append({"talk": talk, "wer": w, "rtf": rtf, "dur": dur})
    print(f" Talk {talk}: {dur:.0f}s, WER {w:.2%}, RTF {rtf:.3f}", flush=True)
acl_wer = np.mean([r["wer"] for r in acl_results])
acl_rtf = np.mean([r["rtf"] for r in acl_results])
print(f" ACL6060 AVERAGE: WER {acl_wer:.2%}, RTF {acl_rtf:.3f}")
# Summary
print(f"\n{'='*60}")
print(f" VOXTRAL BENCHMARK SUMMARY (H100 80GB)")
print(f"{'='*60}")
print(f" {'Dataset':>25} {'WER':>7} {'RTF':>7}")
print(f" {'-'*42}")
print(f" {'LibriSpeech clean':>25} {clean_wer:>6.2%} {clean_rtf:>7.3f}")
print(f" {'LibriSpeech other':>25} {other_wer:>6.2%} {other_rtf:>7.3f}")
print(f" {'ACL6060 (5 talks)':>25} {acl_wer:>6.2%} {acl_rtf:>7.3f}")
# Persist results; numpy floats are coerced to plain floats for JSON.
results = {
    "clean": {"avg_wer": round(float(clean_wer), 4), "rtf": round(float(clean_rtf), 3)},
    "other": {"avg_wer": round(float(other_wer), 4), "rtf": round(float(other_rtf), 3)},
    "acl6060": {"avg_wer": round(float(acl_wer), 4), "avg_rtf": round(float(acl_rtf), 3),
                "talks": [{k: (round(float(v), 4) if isinstance(v, (float, np.floating)) else v) for k, v in r.items()} for r in acl_results]},
}
json.dump(results, open("/home/cloud/bench_voxtral_results.json", "w"), indent=2)
print(f"\nSaved to /home/cloud/bench_voxtral_results.json")

View File

@@ -1,122 +0,0 @@
#!/usr/bin/env python3
"""Benchmark Voxtral via vLLM WebSocket /v1/realtime — proper streaming."""
import asyncio, json, base64, time, wave, re, os
import numpy as np
import websockets
import librosa
from jiwer import wer as compute_wer
MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"
WS_URI = "ws://localhost:8000/v1/realtime"
def norm(t):
    """Normalize for WER: lowercase, non-[a-z0-9 ] chars become spaces, runs collapse."""
    spaced = re.sub(r'[^a-z0-9 ]', ' ', t.lower())
    return re.sub(r' +', ' ', spaced).strip()
async def transcribe(audio_path, max_tokens=4096):
    """Stream one audio file through the vLLM realtime WebSocket endpoint.

    Returns (transcript, audio_duration_s, rtf, first_word_latency_s).
    NOTE(review): max_tokens is currently unused by the body.
    """
    audio, _ = librosa.load(audio_path, sr=16000, mono=True)
    # Convert to raw 16-bit little-endian PCM for the wire.
    pcm16 = (audio * 32767).astype(np.int16).tobytes()
    dur = len(audio) / 16000
    t0 = time.time()
    transcript = ""
    first_token_time = None
    async with websockets.connect(WS_URI, max_size=2**24) as ws:
        await ws.recv()  # session.created
        await ws.send(json.dumps({"type": "session.update", "model": MODEL}))
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))  # signal ready
        # Send audio in 4KB chunks
        for i in range(0, len(pcm16), 4096):
            await ws.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(pcm16[i:i+4096]).decode(),
            }))
        await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
        # Collect streamed deltas until the server signals done, errors,
        # or no message arrives for 120 s.
        while True:
            try:
                msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=120))
                if msg["type"] == "transcription.delta":
                    d = msg.get("delta", "")
                    # First non-whitespace delta marks first-word latency.
                    if d.strip() and first_token_time is None:
                        first_token_time = time.time() - t0
                    transcript += d
                elif msg["type"] == "transcription.done":
                    # Prefer the server's final text when it provides one.
                    transcript = msg.get("text", transcript)
                    break
                elif msg["type"] == "error":
                    break
            except asyncio.TimeoutError:
                break
    elapsed = time.time() - t0
    return transcript.strip(), dur, elapsed / dur, first_token_time or elapsed
async def main():
    """Run the three benchmark suites against the local vLLM realtime server."""
    # Warmup
    print("Warmup...", flush=True)
    await transcribe("/home/cloud/benchmark_data/librispeech_clean_0000.wav")
    # LibriSpeech clean (full 91 samples)
    print("\n=== Voxtral vLLM Realtime / LibriSpeech clean ===", flush=True)
    clean = json.load(open("/home/cloud/benchmark_data/metadata.json"))
    # ta accumulates audio seconds; tp accumulates wall-clock processing seconds.
    wers = []; ta = tp = 0
    for i, s in enumerate(clean):
        hyp, dur, rtf, fwl = await transcribe(s['path'])
        # An empty hypothesis counts as 100% WER.
        w = compute_wer(norm(s['reference']), norm(hyp)) if hyp else 1.0
        wers.append(w); ta += dur; tp += dur * rtf
        if i < 3 or i % 20 == 0:
            print(f" [{i}] {dur:.1f}s RTF={rtf:.3f} FWL={fwl:.2f}s WER={w:.1%} | {hyp[:60]}", flush=True)
    clean_wer = np.mean(wers); clean_rtf = tp / ta
    print(f" CLEAN ({len(clean)}): WER {clean_wer:.2%}, RTF {clean_rtf:.3f}\n", flush=True)
    # LibriSpeech other (full 133 samples)
    print("=== Voxtral vLLM Realtime / LibriSpeech other ===", flush=True)
    other = json.load(open("/home/cloud/benchmark_data/metadata_other.json"))
    wers2 = []; ta2 = tp2 = 0
    for i, s in enumerate(other):
        hyp, dur, rtf, fwl = await transcribe(s['path'])
        w = compute_wer(norm(s['reference']), norm(hyp)) if hyp else 1.0
        wers2.append(w); ta2 += dur; tp2 += dur * rtf
        if i < 3 or i % 20 == 0:
            print(f" [{i}] {dur:.1f}s RTF={rtf:.3f} WER={w:.1%}", flush=True)
    other_wer = np.mean(wers2); other_rtf = tp2 / ta2
    print(f" OTHER ({len(other)}): WER {other_wer:.2%}, RTF {other_rtf:.3f}\n", flush=True)
    # ACL6060 talks
    print("=== Voxtral vLLM Realtime / ACL6060 ===", flush=True)
    acl = []
    for talk in ["110", "117", "268", "367", "590"]:
        # Gold transcript: concatenation of the per-segment JSONL texts.
        gw = []
        with open(f"/home/cloud/iwslt26-sst/inputs/en/acl6060.ts/gold-jsonl/2022.acl-long.{talk}.jsonl") as f:
            for line in f: gw.append(json.loads(line)["text"].strip())
        gold = " ".join(gw)
        hyp, dur, rtf, fwl = await transcribe(f"/home/cloud/acl6060_audio/2022.acl-long.{talk}.wav")
        w = compute_wer(norm(gold), norm(hyp)) if hyp else 1.0
        acl.append({"talk": talk, "wer": round(float(w),4), "rtf": round(float(rtf),3), "dur": round(dur,1)})
        print(f" Talk {talk}: {dur:.0f}s, WER {w:.2%}, RTF {rtf:.3f}, FWL {fwl:.2f}s", flush=True)
    acl_wer = np.mean([r["wer"] for r in acl])
    acl_rtf = np.mean([r["rtf"] for r in acl])
    print(f" ACL6060 AVERAGE: WER {acl_wer:.2%}, RTF {acl_rtf:.3f}\n", flush=True)
    # Summary
    print(f"{'='*55}")
    print(f" VOXTRAL vLLM REALTIME BENCHMARK (H100)")
    print(f"{'='*55}")
    print(f" LS clean ({len(clean)}): WER {clean_wer:.2%}, RTF {clean_rtf:.3f}")
    print(f" LS other ({len(other)}): WER {other_wer:.2%}, RTF {other_rtf:.3f}")
    print(f" ACL6060 (5): WER {acl_wer:.2%}, RTF {acl_rtf:.3f}")
    # Persist results as plain JSON (numpy floats coerced to float).
    results = {
        "clean": {"avg_wer": round(float(clean_wer),4), "rtf": round(float(clean_rtf),3), "n": len(clean)},
        "other": {"avg_wer": round(float(other_wer),4), "rtf": round(float(other_rtf),3), "n": len(other)},
        "acl6060": {"avg_wer": round(float(acl_wer),4), "avg_rtf": round(float(acl_rtf),3), "talks": acl},
    }
    json.dump(results, open("/home/cloud/bench_voxtral_realtime_results.json", "w"), indent=2)
    print(f"\n Saved to /home/cloud/bench_voxtral_realtime_results.json")


asyncio.run(main())

View File

@@ -9,9 +9,10 @@ import json
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
DIR = os.path.dirname(os.path.abspath(__file__))
@@ -113,7 +114,8 @@ def fig_scatter_acl6060():
label_off = [(10, -12), (10, 6), (10, 6), (10, 6)]
for (name, d, color, marker, sz), (lx, ly) in zip(pts, label_off):
wer = d["avg_wer"]; rtf = d["avg_rtf"]
wer = d["avg_wer"]
rtf = d["avg_rtf"]
ax.scatter(rtf, wer, s=sz, c=color, marker=marker,
edgecolors="white", linewidths=1.5, zorder=5)
ax.annotate(name, (rtf, wer), fontsize=9.5, fontweight="bold",
@@ -157,20 +159,26 @@ def fig_bars():
fig, axes = plt.subplots(1, 3, figsize=(16, 6))
# WER
ax = axes[0]; w = 0.36
ax = axes[0]
w = 0.36
ax.bar(x - w/2, wer_c, w, color=cols, alpha=0.9, edgecolor="white", label="test-clean")
ax.bar(x + w/2, wer_o, w, color=cols_l, alpha=0.65, edgecolor="white", label="test-other")
ax.set_ylabel("WER %"); ax.set_title("Word Error Rate", fontweight="bold")
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
ax.legend(fontsize=8); ax.grid(axis="y", alpha=0.15)
ax.set_ylabel("WER %")
ax.set_title("Word Error Rate", fontweight="bold")
ax.set_xticks(x)
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
ax.legend(fontsize=8)
ax.grid(axis="y", alpha=0.15)
for i, v in enumerate(wer_c):
ax.text(i - w/2, v + 0.2, f"{v:.1f}", ha="center", fontsize=7, fontweight="bold")
# RTF
ax = axes[1]
ax.bar(x, rtf_c, 0.55, color=cols, alpha=0.9, edgecolor="white")
ax.set_ylabel("RTF (lower = faster)"); ax.set_title("Real-Time Factor (test-clean)", fontweight="bold")
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
ax.set_ylabel("RTF (lower = faster)")
ax.set_title("Real-Time Factor (test-clean)", fontweight="bold")
ax.set_xticks(x)
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
ax.grid(axis="y", alpha=0.15)
for i, v in enumerate(rtf_c):
ax.text(i, v + 0.003, f"{v:.3f}", ha="center", fontsize=8, fontweight="bold")
@@ -178,8 +186,10 @@ def fig_bars():
# First-word latency
ax = axes[2]
ax.bar(x, fwl, 0.55, color=cols, alpha=0.9, edgecolor="white")
ax.set_ylabel("ms"); ax.set_title("First Word Latency", fontweight="bold")
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
ax.set_ylabel("ms")
ax.set_title("First Word Latency", fontweight="bold")
ax.set_xticks(x)
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
ax.grid(axis="y", alpha=0.15)
for i, v in enumerate(fwl):
ax.text(i, v + 8, f"{v}", ha="center", fontsize=8, fontweight="bold")
@@ -222,8 +232,10 @@ def fig_robustness():
ax.set_xlabel("WER % on test-clean")
ax.set_ylabel("WER % on test-other")
ax.set_title("Clean vs Noisy Robustness (H100 80 GB)", fontsize=13, fontweight="bold", pad=12)
ax.set_xlim(-0.3, 12); ax.set_ylim(-0.3, 12)
ax.set_aspect("equal"); ax.grid(True, alpha=0.12)
ax.set_xlim(-0.3, 12)
ax.set_ylim(-0.3, 12)
ax.set_aspect("equal")
ax.grid(True, alpha=0.12)
_save(fig, "robustness_clean_vs_other.png")
@@ -236,7 +248,8 @@ def fig_per_talk():
talks = DATA["acl6060"]["talks"]
fig, ax = plt.subplots(figsize=(9, 5))
x = np.arange(len(talks)); w = 0.35
x = np.arange(len(talks))
w = 0.35
bars_v = ax.bar(x - w/2, [v[t] for t in talks], w, color=COLORS["voxtral"],
edgecolor="white", label="Voxtral 4B (vLLM)")
@@ -254,8 +267,10 @@ def fig_per_talk():
ax.set_ylabel("WER %")
ax.set_title("Per-Talk WER — ACL6060 Conference Talks (H100 80 GB)",
fontsize=13, fontweight="bold", pad=12)
ax.set_xticks(x); ax.set_xticklabels([f"Talk {t}" for t in talks])
ax.legend(fontsize=9); ax.grid(axis="y", alpha=0.15)
ax.set_xticks(x)
ax.set_xticklabels([f"Talk {t}" for t in talks])
ax.legend(fontsize=9)
ax.grid(axis="y", alpha=0.15)
ax.set_ylim(0, 18)
_save(fig, "acl6060_per_talk.png")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Generate combined M5 vs H100 benchmark figure for WhisperLiveKit.
Produces a WER vs RTF scatter plot comparing Apple M5 (MLX) and
NVIDIA H100 results on LibriSpeech test-clean.
Note: M5 uses per-utterance evaluation (500 samples), while H100
uses chapter-grouped evaluation (91 chapters). Per-utterance WER
is typically lower because short utterances avoid long-range errors.
Run: python3 benchmarks/m5/generate_figures.py
"""
import json
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
DIR = os.path.dirname(os.path.abspath(__file__))
H100_DATA = json.load(open(os.path.join(DIR, "..", "h100", "results.json")))
M5_DATA = json.load(open(os.path.join(DIR, "results.json")))
# -- Style --
plt.rcParams.update({
"font.family": "sans-serif",
"font.size": 11,
"axes.spines.top": False,
"axes.spines.right": False,
})
COLORS = {
"whisper": "#d63031",
"qwen_b": "#6c5ce7",
"qwen_s": "#00b894",
"voxtral": "#fdcb6e",
"m5_qwen": "#0984e3",
}
def _save(fig, name):
    """Write *fig* under DIR as *name* at 180 dpi, close it, and log the name."""
    target = os.path.join(DIR, name)
    fig.savefig(target, dpi=180, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f" saved: {name}")
def fig_m5_vs_h100():
    """WER vs RTF scatter: M5 (MLX) and H100 (CUDA) on LibriSpeech test-clean.

    Reads module-level H100_DATA / M5_DATA and writes
    ``m5_vs_h100_wer_rtf.png`` via ``_save``. Note (also rendered on the
    figure): the two platforms use different WER methodologies, so points
    are indicative rather than directly comparable.
    """
    h100 = H100_DATA["librispeech_clean"]["systems"]
    m5 = M5_DATA["models"]
    fig, ax = plt.subplots(figsize=(10, 7))
    # Light green band for "good WER" zone
    ax.axhspan(0, 5, color="#f0fff0", alpha=0.5, zorder=0)
    # --- H100 points ---
    # Each entry: (label, result dict with "wer"/"rtf", color, marker, size).
    h100_pts = [
        ("Whisper large-v3\n(H100, batch)", h100["whisper_large_v3_batch"], COLORS["whisper"], "h", 220),
        ("Qwen3 0.6B batch\n(H100)", h100["qwen3_0.6b_batch"], COLORS["qwen_b"], "h", 170),
        ("Qwen3 1.7B batch\n(H100)", h100["qwen3_1.7b_batch"], COLORS["qwen_b"], "h", 220),
        ("Voxtral 4B vLLM\n(H100)", h100["voxtral_4b_vllm_realtime"], COLORS["voxtral"], "D", 240),
        ("Qwen3 0.6B SimulStream+KV\n(H100)", h100["qwen3_0.6b_simulstream_kv"], COLORS["qwen_s"], "s", 200),
        ("Qwen3 1.7B SimulStream+KV\n(H100)", h100["qwen3_1.7b_simulstream_kv"], COLORS["qwen_s"], "s", 260),
    ]
    # Hand-tuned per-point label offsets (in points) to avoid overlaps.
    h100_offsets = [(-55, 10), (-55, -22), (8, -18), (8, 10), (8, 10), (8, -18)]
    for (name, d, color, marker, sz), (lx, ly) in zip(h100_pts, h100_offsets):
        ax.scatter(d["rtf"], d["wer"], s=sz, c=color, marker=marker,
                   edgecolors="white", linewidths=1.5, zorder=5)
        ax.annotate(name, (d["rtf"], d["wer"]), fontsize=7.5, fontweight="bold",
                    xytext=(lx, ly), textcoords="offset points",
                    arrowprops=dict(arrowstyle="-", color="#aaa", lw=0.5))
    # --- M5 points ---
    m5_pts = [
        ("Qwen3 0.6B SimulStream\n(M5, MLX)", m5["qwen3-asr-0.6b-simul"], COLORS["m5_qwen"], "^", 260),
        ("Qwen3 1.7B SimulStream\n(M5, MLX)", m5["qwen3-asr-1.7b-simul"], COLORS["m5_qwen"], "^", 300),
    ]
    m5_offsets = [(8, 8), (8, -18)]
    for (name, d, color, marker, sz), (lx, ly) in zip(m5_pts, m5_offsets):
        ax.scatter(d["rtf"], d["wer"], s=sz, c=color, marker=marker,
                   edgecolors="white", linewidths=1.5, zorder=6)
        ax.annotate(name, (d["rtf"], d["wer"]), fontsize=7.5, fontweight="bold",
                    xytext=(lx, ly), textcoords="offset points",
                    arrowprops=dict(arrowstyle="-", color="#aaa", lw=0.5))
    # --- Connecting lines between same models on different hardware ---
    # 0.6B: H100 SimulStream+KV -> M5 SimulStream
    ax.plot([h100["qwen3_0.6b_simulstream_kv"]["rtf"], m5["qwen3-asr-0.6b-simul"]["rtf"]],
            [h100["qwen3_0.6b_simulstream_kv"]["wer"], m5["qwen3-asr-0.6b-simul"]["wer"]],
            "--", color="#0984e3", alpha=0.3, lw=1.5, zorder=3)
    # 1.7B: H100 SimulStream+KV -> M5 SimulStream
    ax.plot([h100["qwen3_1.7b_simulstream_kv"]["rtf"], m5["qwen3-asr-1.7b-simul"]["rtf"]],
            [h100["qwen3_1.7b_simulstream_kv"]["wer"], m5["qwen3-asr-1.7b-simul"]["wer"]],
            "--", color="#0984e3", alpha=0.3, lw=1.5, zorder=3)
    # --- RTF = 1 line (real-time boundary) ---
    ax.axvline(x=1.0, color="#e17055", linestyle=":", alpha=0.5, lw=1.5, zorder=1)
    ax.text(1.02, 0.5, "real-time\nboundary", fontsize=8, color="#e17055",
            fontstyle="italic", alpha=0.7, va="bottom")
    # --- Methodology note ---
    ax.text(0.98, 0.02,
            "H100: chapter-grouped WER (91 chapters) | M5: per-utterance WER (500 samples)\n"
            "Per-utterance WER is typically lower -- results are not directly comparable.",
            transform=ax.transAxes, fontsize=7.5, color="#666",
            ha="right", va="bottom", fontstyle="italic",
            bbox=dict(boxstyle="round,pad=0.3", fc="#fff9e6", ec="#ddd", alpha=0.9))
    ax.set_xlabel("RTF (lower = faster)")
    ax.set_ylabel("WER % (lower = better)")
    ax.set_title("H100 vs M5 (MLX) -- Qwen3-ASR on LibriSpeech test-clean",
                 fontsize=13, fontweight="bold", pad=12)
    ax.set_xlim(-0.01, 1.1)
    ax.set_ylim(-0.5, 10)
    ax.grid(True, alpha=0.12)
    # Legend combines color patches (system family) with gray marker shapes
    # (evaluation mode), rendered in two columns.
    legend = [
        mpatches.Patch(color=COLORS["whisper"], label="Whisper large-v3 (H100)"),
        mpatches.Patch(color=COLORS["qwen_b"], label="Qwen3-ASR batch (H100)"),
        mpatches.Patch(color=COLORS["qwen_s"], label="Qwen3 SimulStream+KV (H100)"),
        mpatches.Patch(color=COLORS["voxtral"], label="Voxtral 4B vLLM (H100)"),
        mpatches.Patch(color=COLORS["m5_qwen"], label="Qwen3 SimulStream (M5, MLX)"),
        plt.Line2D([0], [0], marker="h", color="w", mfc="gray", ms=8, label="Batch mode"),
        plt.Line2D([0], [0], marker="s", color="w", mfc="gray", ms=8, label="Streaming (H100)"),
        plt.Line2D([0], [0], marker="^", color="w", mfc="gray", ms=8, label="Streaming (M5)"),
    ]
    ax.legend(handles=legend, fontsize=8, loc="upper right", framealpha=0.85, ncol=2)
    _save(fig, "m5_vs_h100_wer_rtf.png")
if __name__ == "__main__":
    # Script entry point: regenerate the combined M5 vs H100 figure.
    print("Generating M5 vs H100 benchmark figure...")
    fig_m5_vs_h100()
    print("Done!")

Binary file not shown.

After

Width:  |  Height:  |  Size: 162 KiB

View File

@@ -0,0 +1,9 @@
{
"platform": "Apple M5 (32GB RAM, MLX fp16)",
"dataset": "LibriSpeech test-clean",
"methodology": "per-utterance (500 samples)",
"models": {
"qwen3-asr-0.6b-simul": {"wer": 3.30, "rtf": 0.263},
"qwen3-asr-1.7b-simul": {"wer": 4.07, "rtf": 0.944}
}
}

View File

@@ -418,7 +418,31 @@ class AudioProcessor:
logger.info("Transcription processor task finished.")
async def _update_diarization_state(self, diarization_segments) -> None:
    """Append freshly diarized segments to the shared state.

    No-op for an empty batch. Also advances ``end_attributed_speaker``
    to the latest segment end seen so far (segments without an ``end``
    attribute count as 0.0).
    """
    if not diarization_segments:
        return
    latest_end = max(getattr(seg, "end", 0.0) for seg in diarization_segments)
    async with self.lock:
        self.state.new_diarization.extend(diarization_segments)
        if latest_end > self.state.end_attributed_speaker:
            self.state.end_attributed_speaker = latest_end
async def _drain_diarization_buffer(self) -> None:
    """Repeatedly call ``diarize()`` until the backend's buffer is empty.

    Sortformer-style backends accumulate audio internally and process one
    chunk per ``diarize()`` call, returning an empty list once nothing is
    left; that empty result terminates the loop. Each non-empty batch is
    pushed into the shared state.
    """
    while segments := await self.diarization.diarize():
        await self._update_diarization_state(segments)
async def diarization_processor(self) -> None:
has_buffer = hasattr(self.diarization, 'buffer_audio')
while True:
try:
item = await get_all_from_queue(self.diarization_queue)
@@ -429,16 +453,26 @@ class AudioProcessor:
self.diarization.insert_silence(item.duration)
continue
self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize()
diar_end = 0.0
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
if has_buffer:
await self._drain_diarization_buffer()
else:
# Cumulative backends (e.g. Diart): replace, not extend
diarization_segments = await self.diarization.diarize()
diar_end = 0.0
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
# Drain any remaining audio in the buffer before exiting
if has_buffer:
try:
await self._drain_diarization_buffer()
except Exception as e:
logger.warning(f"Exception draining diarization buffer: {e}")
logger.info("Diarization processor task finished.")
async def translation_processor(self) -> None:

View File

@@ -59,6 +59,7 @@ def detect_available_backends() -> List[str]:
try:
import mlx.core # noqa: F401
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model # noqa: F401
backends.append("voxtral-mlx")
except ImportError:

View File

@@ -233,6 +233,7 @@ def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
def _decode_audio(audio_bytes: bytes) -> tuple:
    """Decode an in-memory audio blob into ``(float32 samples, sample_rate)``."""
    import io

    import soundfile as sf

    samples, sample_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32")
    return np.array(samples, dtype=np.float32), sample_rate

View File

@@ -103,7 +103,6 @@ def print_report(report: BenchmarkReport, out: TextIO = sys.stderr) -> None:
# Per-language breakdown
wer_by_lang = report.wer_by_language()
rtf_by_lang = report.rtf_by_language()
if len(wer_by_lang) > 1:
w(f"\n {BOLD}By Language{RESET}\n")
w(f" {'' * 40}\n")

View File

@@ -46,7 +46,6 @@ class BenchmarkRunner:
async def run(self) -> BenchmarkReport:
"""Run the full benchmark suite and return a report."""
from whisperlivekit.metrics import compute_wer
from whisperlivekit.test_harness import TestHarness
# Get samples
samples = get_benchmark_samples(

View File

@@ -1,116 +0,0 @@
"""
Bridge between WhisperLiveKit STT and IWSLT26 MT pipeline.
Converts streaming ASRToken output from SimulStreaming into the JSONL
format expected by the AlignAtt MT agent (iwslt26-sst).
Output format (one JSON per line):
{"text": "word or phrase", "emission_time": 1.234, "is_final": false, "speech_time": 1.0}
Where:
- text: the emitted word/phrase
- emission_time: wall-clock time when the word was emitted (for compute-aware eval)
- speech_time: timestamp in the audio (for compute-unaware eval)
- is_final: whether this is the last word of a segment/silence boundary
"""
import json
import time
from typing import List, TextIO
from whisperlivekit.timed_objects import ASRToken
class CascadeBridge:
    """Converts ASRToken stream to JSONL for the MT agent.

    Entries are buffered in memory; when ``output_file`` is given, each
    entry is additionally written (and flushed) as it is emitted.
    """

    def __init__(self, output_file: TextIO = None):
        self.output_file = output_file
        # Wall-clock origin for compute-aware emission timestamps.
        self.start_time = time.time()
        self.entries: List[dict] = []

    def emit_tokens(self, tokens: List[ASRToken], is_final: bool = False):
        """Emit a batch of tokens from the STT."""
        wall_clock = time.time() - self.start_time
        last_index = len(tokens) - 1
        for i, token in enumerate(tokens):
            record = {
                "text": token.text.strip(),
                "emission_time": round(wall_clock, 3),
                "speech_time": round(token.start, 3),
                # Only the batch's last token carries the final flag.
                "is_final": is_final and (i == last_index),
            }
            self.entries.append(record)
            if self.output_file:
                self.output_file.write(json.dumps(record) + "\n")
                self.output_file.flush()

    def get_entries(self) -> List[dict]:
        """Return every entry emitted so far."""
        return self.entries

    def get_text(self) -> str:
        """Get the full transcribed text."""
        return " ".join(entry["text"] for entry in self.entries if entry["text"])

    def save(self, path: str):
        """Save all entries to a JSONL file."""
        with open(path, "w") as f:
            f.writelines(json.dumps(entry) + "\n" for entry in self.entries)
def run_stt_to_jsonl(
    audio_path: str,
    output_path: str,
    model_id: str = "Qwen/Qwen3-ASR-0.6B",
    alignment_heads_path: str = None,
    border_fraction: float = 0.20,
    language: str = "en",
    chunk_sec: float = 1.0,
):
    """Run STT on an audio file and save JSONL output for the MT agent.

    This is the main entry point for the cascade: audio file → JSONL.
    Returns the populated :class:`CascadeBridge`.
    """
    import wave

    import numpy as np

    from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR, Qwen3SimulKVOnlineProcessor

    # Decode the 16-bit PCM WAV into float32 samples in [-1, 1].
    with wave.open(audio_path, 'r') as wav_file:
        samples = np.frombuffer(
            wav_file.readframes(wav_file.getnframes()), dtype=np.int16
        ).astype(np.float32) / 32768.0

    recognizer = Qwen3SimulKVASR(
        model_dir=model_id,
        lan=language,
        alignment_heads_path=alignment_heads_path,
        border_fraction=border_fraction,
    )
    processor = Qwen3SimulKVOnlineProcessor(recognizer)
    bridge = CascadeBridge()

    # Feed the audio in fixed-size chunks, emitting words as they appear.
    samples_per_chunk = int(chunk_sec * 16000)
    cursor = 0
    elapsed_audio = 0.0
    while cursor < len(samples):
        piece = samples[cursor:cursor + samples_per_chunk]
        elapsed_audio += len(piece) / 16000
        processor.insert_audio_chunk(piece, elapsed_audio)
        words, _ = processor.process_iter(is_last=False)
        if words:
            bridge.emit_tokens(words, is_final=False)
        cursor += samples_per_chunk

    # Flush anything still pending in the decoder.
    final_words, _ = processor.finish()
    if final_words:
        bridge.emit_tokens(final_words, is_final=True)

    bridge.save(output_path)
    return bridge

View File

@@ -386,7 +386,8 @@ def cmd_models():
# --- System info ---
print(f"\n Platform: {platform.system()} {platform.machine()}")
print(f" Accelerator: {_gpu_info()}")
print(f" ffmpeg: {'found' if _check_ffmpeg() else '\033[31mNOT FOUND\033[0m (required)'}")
_ffmpeg_status = "found" if _check_ffmpeg() else "\033[31mNOT FOUND\033[0m (required)"
print(f" ffmpeg: {_ffmpeg_status}")
# --- Model catalog ---
print("\n Models:\n")
@@ -419,7 +420,7 @@ def cmd_models():
)
# --- Quick start ---
print(f"\n Quick start:\n")
print("\n Quick start:\n")
if is_apple_silicon:
print(" wlk run voxtral-mlx # Best streaming on Apple Silicon")
print(" wlk run large-v3-turbo # Best quality/speed balance")
@@ -806,7 +807,7 @@ async def _run_bench_new(parsed, languages, categories):
on_progress=on_progress,
)
print(f"\n Downloading benchmark samples (cached after first run)...",
print("\n Downloading benchmark samples (cached after first run)...",
file=sys.stderr)
report = await runner.run()

View File

@@ -121,6 +121,15 @@ class TranscriptionEngine:
self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
logger.info("Using Voxtral HF Transformers streaming backend")
elif config.backend == "qwen3-mlx-simul":
from whisperlivekit.qwen3_mlx_simul import Qwen3MLXSimulStreamingASR
self.tokenizer = None
self.asr = Qwen3MLXSimulStreamingASR(
**transcription_common_params,
alignment_heads_path=config.custom_alignment_heads,
border_fraction=getattr(config, 'border_fraction', 0.15),
)
logger.info("Using Qwen3 MLX SimulStreaming backend")
elif config.backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXASR
self.tokenizer = None
@@ -247,6 +256,9 @@ def online_factory(args, asr, language=None):
if backend == "qwen3-simul-kv":
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVOnlineProcessor
return Qwen3SimulKVOnlineProcessor(asr)
if backend == "qwen3-mlx-simul":
from whisperlivekit.qwen3_mlx_simul import Qwen3MLXSimulStreamingOnlineProcessor
return Qwen3MLXSimulStreamingOnlineProcessor(asr)
if backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor
return Qwen3MLXOnlineProcessor(asr)

View File

@@ -147,8 +147,8 @@ def parse_args():
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-mlx' for Qwen3-ASR on Apple Silicon (MLX). Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-mlx-simul", "qwen3-simul", "vllm-realtime"],
help="Select the ASR backend implementation. Use 'qwen3-mlx-simul' for Qwen3-ASR SimulStreaming on Apple Silicon (MLX). Use 'qwen3-mlx' for Qwen3-ASR LocalAgreement on MLX. Use 'qwen3-simul' for Qwen3-ASR SimulStreaming (PyTorch). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
)
parser.add_argument(
"--no-vac",

View File

@@ -0,0 +1,759 @@
"""
Qwen3-ASR SimulStreaming (AlignAtt) on MLX for Apple Silicon.
Uses the ``mlx_qwen3_asr`` library for model loading, audio encoding, and
tokenization. Implements the AlignAtt border-distance policy by monkey-
patching ``TextAttention.__call__`` on alignment layers to capture Q (with
RoPE) during autoregressive decode steps, then computing ``Q @ K_audio^T``
from the KV cache to find the most-attended audio frame.
This is the MLX equivalent of ``qwen3_simul.py`` (PyTorch) which uses
``register_forward_hook`` for the same purpose.
"""
import json
import logging
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16_000
# Model size aliases (same as qwen3_mlx_asr.py).
# Whisper-style size names map onto the two available Qwen3-ASR checkpoints.
QWEN3_MLX_MODEL_MAPPING = {
    "base": "Qwen/Qwen3-ASR-0.6B",
    "tiny": "Qwen/Qwen3-ASR-0.6B",
    "small": "Qwen/Qwen3-ASR-0.6B",
    "large": "Qwen/Qwen3-ASR-1.7B",
    "medium": "Qwen/Qwen3-ASR-1.7B",
    "large-v3": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
    "qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
    "1.7b": "Qwen/Qwen3-ASR-1.7B",
    "0.6b": "Qwen/Qwen3-ASR-0.6B",
}

# Whisper language codes -> Qwen3 canonical language names
WHISPER_TO_QWEN3_LANGUAGE = {
    "zh": "Chinese", "en": "English", "yue": "Cantonese",
    "ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish",
    "pt": "Portuguese", "id": "Indonesian", "it": "Italian",
    "ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese",
    "ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay",
    "nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
    "pl": "Polish", "cs": "Czech", "fa": "Persian",
    "el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian",
}
# Inverse map: Qwen3 canonical names back to Whisper codes (used when the
# model auto-detects the language).
QWEN3_TO_WHISPER_LANGUAGE = {v: k for k, v in WHISPER_TO_QWEN3_LANGUAGE.items()}
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class Qwen3MLXSimulConfig:
    """Tunable knobs for the AlignAtt simul-streaming policy."""

    # Whisper-style language code, or "auto" to detect from the model output.
    language: str = "auto"
    # Optional JSON file with precomputed (layer, head) alignment pairs.
    alignment_heads_path: Optional[str] = None
    # Stop emitting once attention reaches the last fraction of audio frames
    # (threshold = max(2, n_audio_tokens * border_fraction) in _infer).
    border_fraction: float = 0.15
    # A backward attention jump larger than this fraction triggers an early
    # stop ("rewind") in the decode loop.
    rewind_fraction: float = 0.12
    # Do not run inference on less than this many seconds of buffered audio.
    audio_min_len: float = 3.0
    # Rolling audio buffer cap, in seconds; older samples are trimmed away.
    audio_max_len: float = 15.0
    # How many committed token ids are replayed as decoding context.
    max_context_tokens: int = 30
    # Upper bound on the number of alignment heads actually used.
    max_alignment_heads: int = 20
# ---------------------------------------------------------------------------
# Per-session state
# ---------------------------------------------------------------------------
@dataclass
class _SessionState:
    """Mutable per-session decoding state (one instance per client stream)."""

    # Rolling float32 PCM buffer of not-yet-finalized audio at 16 kHz.
    audio_buffer: np.ndarray = field(
        default_factory=lambda: np.array([], dtype=np.float32)
    )
    # Seconds trimmed off the front of audio_buffer so far.
    cumulative_time_offset: float = 0.0
    global_time_offset: float = 0.0
    speaker: int = -1
    # Most recently attended audio frame, carried across _infer() calls for
    # the rewind check. NOTE(review): the -15 sentinel presumably keeps the
    # first rewind comparison from firing — confirm against qwen3_simul.py.
    last_attend_frame: int = -15
    committed_word_count: int = 0
    # Token ids already emitted; a suffix is replayed as decoder context.
    committed_token_ids: List[int] = field(default_factory=list)
    detected_language: Optional[str] = None
    # Buffer length at the previous inference; used to size the budget of
    # "fresh" audio and to throttle process_iter.
    last_infer_samples: int = 0
    # Pending partial word from previous _infer() call.
    # When a border stops mid-word (e.g., "Vill" from "Villard"),
    # the partial is held here and prepended to the next call's output.
    pending_partial: str = ""
    pending_partial_start: Optional[float] = None
    # Whether the first emitted token of this call is a continuation of the
    # previous call's last word (no leading space → subword continuation).
    first_emit_is_continuation: bool = False
# ---------------------------------------------------------------------------
# Shared model holder
# ---------------------------------------------------------------------------
class Qwen3MLXSimulStreamingASR:
    """Loads the Qwen3-ASR model via ``mlx_qwen3_asr`` once and keeps it
    alive for the lifetime of the server. Shared across sessions."""

    sep = ""
    SAMPLING_RATE = SAMPLE_RATE

    def __init__(
        self,
        model_size: str = None,
        model_dir: str = None,
        model_path: str = None,
        lan: str = "auto",
        alignment_heads_path: Optional[str] = None,
        border_fraction: float = 0.15,
        warmup_file: Optional[str] = None,
        # The following parameters are accepted for interface parity with the
        # other ASR backends but are not used by this MLX implementation.
        model_cache_dir: Optional[str] = None,
        lora_path: Optional[str] = None,
        min_chunk_size: float = 0.1,
        direct_english_translation: bool = False,
        **kwargs,
    ):
        """Load model, tokenizer and alignment heads; optionally run warmup.

        ``model_dir``/``model_path`` take precedence over ``model_size``
        aliases; a bare size falls back to QWEN3_MLX_MODEL_MAPPING.
        """
        import mlx.core as mx
        import mlx_qwen3_asr
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        self.warmup_file = warmup_file
        self.cfg = Qwen3MLXSimulConfig(
            language=lan,
            alignment_heads_path=alignment_heads_path,
            border_fraction=border_fraction,
        )
        # Resolve model path: explicit dir/path wins; otherwise treat the
        # size string as a repo id / local path if it looks like one, else
        # map it through the alias table.
        resolved = model_dir or model_path
        if not resolved:
            size = (model_size or "base").lower()
            if "/" in size or size.startswith("."):
                resolved = size
            else:
                resolved = QWEN3_MLX_MODEL_MAPPING.get(size, "Qwen/Qwen3-ASR-0.6B")
        t0 = time.time()
        logger.info("Loading Qwen3-ASR MLX model '%s' for SimulStreaming ...", resolved)
        self.model, self._config = mlx_qwen3_asr.load_model(resolved, dtype=mx.float16)
        logger.info("Model loaded in %.2fs", time.time() - t0)
        # Tokenizer — prefer the path the loader actually resolved to.
        tok_path = getattr(self.model, "_resolved_model_path", None) or resolved
        self.tokenizer = mlx_qwen3_asr.tokenizer.Tokenizer(str(tok_path))
        # Architecture info (needed for GQA head mapping in the capture hooks).
        text_cfg = self._config.text_config
        self.num_layers = text_cfg.num_hidden_layers
        self.num_heads = text_cfg.num_attention_heads
        self.num_kv_heads = text_cfg.num_key_value_heads
        self.head_dim = text_cfg.head_dim
        self.gqa_ratio = self.num_heads // self.num_kv_heads
        self.audio_token_id = self._config.audio_token_id
        logger.info(
            "Qwen3-ASR arch: %d layers x %d heads (%d kv), head_dim=%d, GQA=%d",
            self.num_layers, self.num_heads, self.num_kv_heads,
            self.head_dim, self.gqa_ratio,
        )
        # Alignment heads, grouped by layer for hook installation.
        self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
        self.heads_by_layer = {}
        for layer_idx, head_idx in self.alignment_heads:
            self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)
        self.backend_choice = "qwen3-mlx-simul"
        # Warmup
        if warmup_file:
            from whisperlivekit.warmup import load_file
            audio = load_file(warmup_file)
            if audio is not None:
                self._warmup(audio)

    def _load_alignment_heads(
        self, path: Optional[str],
    ) -> List[Tuple[int, int]]:
        """Return up to ``max_alignment_heads`` (layer, head) pairs.

        Reads ``alignment_heads_compact`` from the JSON file at *path* when
        available; otherwise falls back to all heads in the last quarter of
        layers (a heuristic, logged as a warning).
        """
        max_heads = self.cfg.max_alignment_heads
        if path and Path(path).exists():
            with open(path) as f:
                data = json.load(f)
            all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
            heads = all_heads[:max_heads]
            logger.info(
                "Loaded top %d alignment heads from %s (of %d total)",
                len(heads), path, len(all_heads),
            )
            return heads
        # Default heuristic: last quarter of layers, all heads
        default_heads = []
        start_layer = self.num_layers * 3 // 4
        for layer in range(start_layer, self.num_layers):
            for head in range(self.num_heads):
                default_heads.append((layer, head))
        logger.warning(
            "No alignment heads file. Using default heuristic: "
            "%d heads from layers %d-%d.",
            len(default_heads), start_layer, self.num_layers - 1,
        )
        return default_heads[:max_heads]

    def _warmup(self, audio: np.ndarray):
        """Run one tiny prefill (2s of audio) to trigger MLX compilation."""
        import mlx.core as mx
        try:
            from mlx_qwen3_asr.audio import compute_features
            audio = audio[:SAMPLE_RATE * 2]
            mel, feat_lens = compute_features(audio)
            mel = mel.astype(mx.float16)
            audio_features, _ = self.model.audio_tower(mel, feat_lens)
            n_audio = int(audio_features.shape[1])
            prompt = self.tokenizer.build_prompt_tokens(n_audio, language="English")
            input_ids = mx.array([prompt])
            positions = mx.arange(input_ids.shape[1])[None, :]
            position_ids = mx.stack([positions, positions, positions], axis=1)
            cache = self.model.create_cache()
            logits = self.model.prefill(input_ids, audio_features, position_ids, cache)
            mx.eval(logits)
            logger.info("Qwen3 MLX SimulStreaming warmup complete")
        except Exception as e:
            # Warmup is best-effort; a failure only costs first-request latency.
            logger.warning("Warmup failed: %s", e)

    def transcribe(self, audio):
        pass  # all work in the online processor
# ---------------------------------------------------------------------------
# Attention capture via wrapper replacement
# ---------------------------------------------------------------------------
class _AttnCaptureWrapper:
    """Wraps a TextAttention module to capture alignment scores during decode.

    Replaces ``layer.self_attn`` with this wrapper. On decode steps (L=1),
    recomputes Q with RoPE, reads cached K from the audio region, computes
    ``Q @ K_audio^T`` for alignment heads, and stores the argmax frame in
    ``capture["step_frames"]``.

    Python dunder resolution (``__call__``) goes through the *class*, not the
    instance, so monkey-patching ``attn.__call__`` on an ``nn.Module`` does
    not work. This wrapper class defines its own ``__call__`` and delegates
    everything else to the wrapped module via ``__getattr__``.
    """

    def __init__(self, original, layer_idx, head_indices, gqa_ratio,
                 audio_start, audio_end, capture):
        # Store in __dict__ directly to avoid triggering __getattr__
        self.__dict__["_original"] = original
        self.__dict__["_layer_idx"] = layer_idx
        self.__dict__["_head_indices"] = head_indices
        self.__dict__["_gqa_ratio"] = gqa_ratio
        self.__dict__["_audio_start"] = audio_start
        self.__dict__["_audio_end"] = audio_end
        self.__dict__["_capture"] = capture

    def __call__(self, x, cos, sin, mask=None, cache=None, layer_idx=0):
        import mlx.core as mx
        from mlx_qwen3_asr.mrope import apply_rotary_pos_emb
        orig = self.__dict__["_original"]
        B, L, _ = x.shape
        # Only instrument single-token decode steps with a populated cache;
        # prefill passes straight through to the wrapped module below.
        if L == 1 and cache is not None:
            li = self.__dict__["_layer_idx"]
            h_indices = self.__dict__["_head_indices"]
            gqa = self.__dict__["_gqa_ratio"]
            a_start = self.__dict__["_audio_start"]
            a_end = self.__dict__["_audio_end"]
            cap = self.__dict__["_capture"]
            # Recompute Q with RoPE (cheap: single token)
            q = orig.q_proj(x)
            q = q.reshape(B, L, orig.num_heads, orig.head_dim)
            q = orig.q_norm(q)
            q = q.transpose(0, 2, 1, 3)  # (B, H, 1, D)
            q_rope, _ = apply_rotary_pos_emb(q, q, cos, sin)
            # K from cache (already has RoPE baked in from cache.update)
            k_cached = cache.keys[li]
            if k_cached is not None and a_end <= k_cached.shape[2]:
                for h_idx in h_indices:
                    # Map query head to its shared KV head under GQA.
                    kv_h = h_idx // gqa
                    q_h = q_rope[0, h_idx, 0]  # (head_dim,)
                    k_audio = k_cached[0, kv_h, a_start:a_end]  # (n_audio, D)
                    scores = k_audio @ q_h  # (n_audio,)
                    frame = int(mx.argmax(scores).item())
                    cap["step_frames"].append(frame)
        return orig(x, cos, sin, mask=mask, cache=cache, layer_idx=layer_idx)

    def __getattr__(self, name):
        # Delegate any attribute not stored above to the wrapped module.
        return getattr(self.__dict__["_original"], name)
def _install_alignment_hooks(model, heads_by_layer, gqa_ratio, audio_start, audio_end, capture):
    """Swap ``self_attn`` on each alignment layer for a capture wrapper.

    Returns a list of ``(layer_idx, original_attn)`` pairs so the caller
    can undo the replacement with ``_remove_alignment_hooks``.
    """
    replaced = []
    n_layers = len(model.model.layers)
    for layer_idx, head_indices in heads_by_layer.items():
        # Heads files may reference layers a smaller model doesn't have.
        if layer_idx >= n_layers:
            continue
        layer = model.model.layers[layer_idx]
        previous = layer.self_attn
        layer.self_attn = _AttnCaptureWrapper(
            previous, layer_idx, head_indices, gqa_ratio,
            audio_start, audio_end, capture,
        )
        replaced.append((layer_idx, previous))
    return replaced
def _remove_alignment_hooks(model, originals):
"""Restore original self_attn modules."""
for layer_idx, orig_attn in originals:
model.model.layers[layer_idx].self_attn = orig_attn
# ---------------------------------------------------------------------------
# Per-session online processor
# ---------------------------------------------------------------------------
class Qwen3MLXSimulStreamingOnlineProcessor:
"""Per-session processor implementing AlignAtt on MLX.
Same interface as other online processors:
insert_audio_chunk / process_iter / get_buffer / start_silence /
end_silence / finish / warmup / new_speaker.
"""
SAMPLING_RATE = SAMPLE_RATE
MIN_DURATION_REAL_SILENCE = 5
def __init__(self, asr: Qwen3MLXSimulStreamingASR, logfile=sys.stderr):
    """Bind the shared model holder and start a fresh session."""
    self.state = _SessionState()
    # Tokens not yet committed (kept for interface parity with other processors).
    self.buffer: List[ASRToken] = []
    # Stream position (seconds) of the newest audio received.
    self.end = 0.0
    self.logfile = logfile
    self.asr = asr
# -- properties expected by AudioProcessor --
@property
def speaker(self):
    # Current speaker id (-1 by default); stored on the session state.
    return self.state.speaker

@speaker.setter
def speaker(self, value):
    self.state.speaker = value

@property
def global_time_offset(self):
    # NOTE(review): presumably maps session-local times onto the global
    # stream timeline — confirm against AudioProcessor's usage.
    return self.state.global_time_offset

@global_time_offset.setter
def global_time_offset(self, value):
    self.state.global_time_offset = value
# -- audio ingestion --
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
    """Append new PCM samples and cap the rolling buffer at ``audio_max_len``.

    When old samples are dropped, the session's cumulative time offset and
    the "already inferred" sample marker shift by the trimmed amount.
    """
    state = self.state
    self.end = audio_stream_end_time
    state.audio_buffer = np.append(state.audio_buffer, audio)
    capacity = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
    overflow = len(state.audio_buffer) - capacity
    if overflow > 0:
        state.audio_buffer = state.audio_buffer[overflow:]
        state.cumulative_time_offset += overflow / self.SAMPLING_RATE
        state.last_infer_samples = max(0, state.last_infer_samples - overflow)
# -- main processing --
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
    """Run one streaming step.

    Returns ``(committed_words, stream_end_time)``. The word list is empty
    when there is not enough buffered audio, not enough *new* audio since
    the last inference (unless ``is_last``), or when inference fails.
    """
    buffered = len(self.state.audio_buffer)
    if buffered / self.SAMPLING_RATE < self.asr.cfg.audio_min_len:
        return [], self.end
    # Throttle: at least 1s of new audio
    fresh = buffered - self.state.last_infer_samples
    if not (is_last or fresh >= int(1.0 * self.SAMPLING_RATE)):
        return [], self.end
    try:
        words = self._infer(is_last)
    except Exception as e:
        logger.exception("Qwen3 MLX SimulStreaming inference error: %s", e)
        return [], self.end
    # Update the budget marker after _infer() so the decoder can size its
    # generation budget using the real amount of fresh audio.
    self.state.last_infer_samples = len(self.state.audio_buffer)
    if not words:
        return [], self.end
    self.buffer = []
    return words, self.end
def _infer(self, is_last: bool) -> List[ASRToken]:
    """Run one inference cycle with alignment-head-based stopping.

    Encodes the buffered audio, prefills the decoder with the prompt plus
    a tail of previously committed tokens as context, then decodes greedily
    while alignment hooks record which audio frame each step attends to.
    Decoding stops early (AlignAtt) when attention approaches the end of
    the available audio or jumps backwards ("rewind"), so only stable
    tokens are committed.

    Returns newly committed, timestamped words (possibly empty).

    Fixes vs the original: the duplicated ``if emit_up_to <= 0`` guard is
    collapsed into one check, and the unused ``clean_text`` binding from
    ``parse_asr_output`` is discarded explicitly.
    """
    import mlx.core as mx
    from mlx_qwen3_asr.audio import compute_features
    from mlx_qwen3_asr.generate import _detect_repetition
    asr = self.asr
    state = self.state
    model = asr.model
    # 1. Encode audio
    mel, feat_lens = compute_features(state.audio_buffer)
    mel = mel.astype(mx.float16)
    audio_features, _ = model.audio_tower(mel, feat_lens)
    n_audio_tokens = int(audio_features.shape[1])
    mx.eval(audio_features)
    if n_audio_tokens == 0:
        return []
    audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE
    # 2. Build prompt tokens
    lan = asr.cfg.language
    language = None
    if lan and lan != "auto":
        language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
    prompt_tokens = asr.tokenizer.build_prompt_tokens(
        n_audio_tokens=n_audio_tokens,
        language=language,
    )
    # Append committed context tokens (tail only, bounded by config)
    if state.committed_token_ids:
        ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
        prompt_tokens.extend(ctx)
    input_ids = mx.array([prompt_tokens])
    seq_len = input_ids.shape[1]
    # 3. Find audio token range within the prompt
    audio_positions = [
        i for i, t in enumerate(prompt_tokens) if t == asr.audio_token_id
    ]
    if not audio_positions:
        return []
    audio_start = audio_positions[0]
    audio_end = audio_positions[-1] + 1
    # 4. MRoPE position IDs (three identical planes)
    positions = mx.arange(seq_len, dtype=mx.int32)[None, :]
    position_ids = mx.stack([positions, positions, positions], axis=1)
    # 5. Prefill
    cache = model.create_cache(max_seq_len=seq_len + 120)
    logits = model.prefill(input_ids, audio_features, position_ids, cache)
    mx.eval(logits)
    # 6. Install alignment hooks
    capture = {"step_frames": []}
    originals = _install_alignment_hooks(
        model, asr.heads_by_layer, asr.gqa_ratio,
        audio_start, audio_end, capture,
    )
    # 7. Decode loop with border-distance policy
    eos_ids = set(asr.tokenizer.EOS_TOKEN_IDS)
    per_step_frames: List[List[int]] = []
    last_attend_frame = state.last_attend_frame
    border_stop_step: Optional[int] = None
    border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
    rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
    # Max tokens: ~6 tokens/sec of speech + margin
    new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
    if is_last:
        max_tokens = min(int(audio_duration * 6) + 10, 120)
    else:
        max_tokens = min(int(max(new_audio_secs, 1.0) * 6) + 5, 40)
    token = int(mx.argmax(logits.reshape(-1)).item())
    generated = [token]
    try:
        for step in range(1, max_tokens):
            if token in eos_ids:
                break
            if _detect_repetition(generated):
                break
            next_ids = mx.array([[token]])
            pos_val = seq_len + step - 1
            next_pos = mx.array([[[pos_val], [pos_val], [pos_val]]], dtype=mx.int32)
            logits = model.step(next_ids, next_pos, cache, validate_input_ids=False)
            mx.eval(logits)
            token = int(mx.argmax(logits.reshape(-1)).item())
            generated.append(token)
            # Collect frames captured by the hooks during this step
            if capture["step_frames"]:
                per_step_frames.append(capture["step_frames"])
                capture["step_frames"] = []
            # Border-distance check (skip first 3 steps)
            if (not is_last
                    and border_stop_step is None
                    and len(per_step_frames) >= 3):
                latest = per_step_frames[-1]
                if latest:
                    # Median of the head votes for this step
                    frames_sorted = sorted(latest)
                    attended = frames_sorted[len(frames_sorted) // 2]
                    # Rewind check: attention jumped backwards too far
                    if last_attend_frame - attended > rewind_threshold:
                        border_stop_step = max(0, len(per_step_frames) - 2)
                        break
                    last_attend_frame = attended
                    # Border check: attention too close to the audio end
                    if (n_audio_tokens - attended) <= border_threshold:
                        border_stop_step = len(per_step_frames) - 1
                        break
            # Periodic eval to prevent graph buildup
            if step % 8 == 0:
                mx.eval(cache.keys[-1])
    finally:
        _remove_alignment_hooks(model, originals)
    # Flush remaining frames captured after the last append
    if capture["step_frames"]:
        per_step_frames.append(capture["step_frames"])
    state.last_attend_frame = last_attend_frame
    # 8. Process generated tokens: remove trailing EOS
    while generated and generated[-1] in eos_ids:
        generated.pop()
    num_gen = len(generated)
    if num_gen == 0:
        return []
    raw_text = asr.tokenizer.decode(generated)
    logger.info(
        "SimulStreaming raw: %d tokens (border_stop=%s), text=%r",
        num_gen, border_stop_step, raw_text[:100],
    )
    # 9. Strip metadata prefix ("language English<asr_text>...");
    # only the detected language is needed here, the cleaned text is unused.
    from mlx_qwen3_asr.tokenizer import parse_asr_output
    detected_lang, _ = parse_asr_output(
        raw_text,
        user_language=language,
    )
    # Find how many tokens to skip for metadata
    metadata_offset = 0
    asr_text_tokens = asr.tokenizer.encode("<asr_text>")
    asr_text_id = asr_text_tokens[0] if asr_text_tokens else None
    if asr_text_id is not None:
        for i in range(min(num_gen, 10)):
            if generated[i] == asr_text_id:
                metadata_offset = i + 1
                break
    if metadata_offset > 0:
        generated = generated[metadata_offset:]
        num_gen -= metadata_offset
        per_step_frames = per_step_frames[metadata_offset:]
    if num_gen <= 0:
        return []
    # Detect language (first successful detection wins for the session)
    if state.detected_language is None and detected_lang and detected_lang != "unknown":
        state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
            detected_lang, detected_lang.lower(),
        )
        logger.info("Auto-detected language: %s", state.detected_language)
    # 10. Determine how many tokens to emit
    step_frames = [f for f in per_step_frames if f]
    if border_stop_step is not None:
        emit_up_to = min(border_stop_step, num_gen)
    else:
        emit_up_to = num_gen
    # Single guard (the original duplicated this check around the slice).
    if emit_up_to <= 0:
        return []
    emitted_ids = generated[:emit_up_to]
    # 11. Build timestamped words
    words = self._build_timestamped_words(
        emitted_ids, step_frames, emit_up_to,
        n_audio_tokens, audio_duration,
    )
    # Update state
    state.committed_word_count += len(words)
    state.committed_token_ids.extend(emitted_ids)
    return words
def _build_timestamped_words(
self,
generated_ids: List[int],
step_frames: List[List[int]],
emit_up_to: int,
n_audio_tokens: int,
audio_duration: float,
) -> List[ASRToken]:
"""Build timestamped ASRToken list from generated tokens and
alignment-head captured frames."""
state = self.state
asr = self.asr
# Per-token attended frame (median of head votes)
per_token_frame: List[Optional[int]] = []
for step_idx in range(emit_up_to):
if step_idx < len(step_frames) and step_frames[step_idx]:
frames = sorted(step_frames[step_idx])
per_token_frame.append(frames[len(frames) // 2])
else:
per_token_frame.append(None)
# Decode full text, split into words
full_text = asr.tokenizer.decode(generated_ids[:emit_up_to])
text_words = full_text.split()
# Map words to frames proportionally
all_frames = [f for f in per_token_frame if f is not None]
word_frame_pairs = []
for wi, word in enumerate(text_words):
if all_frames:
frac = wi / max(len(text_words), 1)
frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
frame = all_frames[frame_idx]
else:
frame = None
word_frame_pairs.append((word, frame))
# Convert to ASRToken
tokens = []
for i, (text, frame) in enumerate(word_frame_pairs):
text = text.strip()
if not text:
continue
if frame is not None and n_audio_tokens > 0:
timestamp = (
frame / n_audio_tokens * audio_duration
+ state.cumulative_time_offset
)
else:
timestamp = (
(i / max(len(word_frame_pairs), 1)) * audio_duration
+ state.cumulative_time_offset
)
is_very_first_word = (i == 0 and state.committed_word_count == 0)
display_text = text if is_very_first_word else " " + text
token = ASRToken(
start=round(timestamp, 2),
end=round(timestamp + 0.1, 2),
text=display_text,
speaker=state.speaker,
detected_language=state.detected_language,
).with_offset(state.global_time_offset)
tokens.append(token)
return tokens
# -- silence / speaker / lifecycle --
def start_silence(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end
def end_silence(self, silence_duration: float, offset: float):
self.end += silence_duration
long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
if not long_silence:
gap_len = int(self.SAMPLING_RATE * silence_duration)
if gap_len > 0:
gap_silence = np.zeros(gap_len, dtype=np.float32)
self.state.audio_buffer = np.append(
self.state.audio_buffer, gap_silence,
)
else:
self.state = _SessionState()
self.state.global_time_offset = silence_duration + offset
def new_speaker(self, change_speaker):
self.process_iter(is_last=True)
self.state = _SessionState()
self.state.speaker = change_speaker.speaker
self.state.global_time_offset = change_speaker.start
def get_buffer(self) -> Transcript:
return Transcript.from_tokens(tokens=self.buffer, sep='')
def warmup(self, audio: np.ndarray, init_prompt: str = ""):
try:
self.state.audio_buffer = audio[:SAMPLE_RATE]
self.process_iter(is_last=True)
self.state = _SessionState()
logger.info("Qwen3 MLX SimulStreaming processor warmed up")
except Exception as e:
logger.warning("Warmup failed: %s", e)
self.state = _SessionState()
def finish(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end

View File

@@ -622,9 +622,6 @@ class Qwen3SimulStreamingOnlineProcessor:
thinker = asr.model.thinker
try:
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
_get_feat_extract_output_lengths,
)
n_audio_tokens = audio_embeds.shape[0]
@@ -719,7 +716,6 @@ class Qwen3SimulStreamingOnlineProcessor:
return [], self.end
logger.info("Running SimulStreaming inference on %.2fs of audio (%.2fs new)", audio_duration, new_samples / self.SAMPLING_RATE)
self.state.last_infer_samples = len(self.state.audio_buffer)
try:
timestamped_words = self._infer(is_last)
@@ -727,6 +723,10 @@ class Qwen3SimulStreamingOnlineProcessor:
logger.exception("Qwen3 SimulStreaming inference error: %s", e)
return [], self.end
# Update the decode-budget marker after inference so _infer() sees the
# true amount of newly arrived audio.
self.state.last_infer_samples = len(self.state.audio_buffer)
logger.info("SimulStreaming produced %d words", len(timestamped_words))
if not timestamped_words:
return [], self.end

View File

@@ -158,7 +158,9 @@ class Qwen3SimulKVASR:
_patch_transformers_compat()
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor,
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
@@ -210,7 +212,18 @@ class Qwen3SimulKVASR:
if path and Path(path).exists():
with open(path) as f:
data = json.load(f)
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
if "alignment_heads_compact" in data:
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
elif "token_alignment_heads" in data:
all_heads = [
(int(h["layer"]), int(h["head"]))
for h in data["token_alignment_heads"]
]
else:
raise KeyError(
"alignment_heads_compact/token_alignment_heads not found in "
f"{path}"
)
heads = all_heads[:max_heads]
logger.info("Loaded top %d alignment heads from %s", len(heads), path)
return heads
@@ -333,6 +346,21 @@ class Qwen3SimulKVOnlineProcessor:
def get_buffer(self) -> Transcript:
return Transcript.from_tokens(tokens=self.buffer, sep='')
@staticmethod
def _normalize_audio_embeds(audio_embeds: torch.Tensor) -> torch.Tensor:
"""Keep cached audio embeddings in a consistent 2D layout."""
if audio_embeds.dim() == 3:
if audio_embeds.shape[0] != 1:
raise ValueError(
f"Unexpected batched audio embeds shape: {tuple(audio_embeds.shape)}",
)
audio_embeds = audio_embeds[0]
if audio_embeds.dim() == 1:
audio_embeds = audio_embeds.unsqueeze(0)
if audio_embeds.dim() != 2:
raise ValueError(f"Unexpected audio embeds shape: {tuple(audio_embeds.shape)}")
return audio_embeds
def _encode_audio(self) -> Tuple[torch.Tensor, int]:
"""Encode full audio buffer, with caching for stable windows."""
asr = self.asr
@@ -364,8 +392,7 @@ class Qwen3SimulKVOnlineProcessor:
audio_embeds = asr.model.thinker.get_audio_features(
input_features, feature_attention_mask=feature_attention_mask,
)
if audio_embeds.dim() == 3:
audio_embeds = audio_embeds[0]
audio_embeds = self._normalize_audio_embeds(audio_embeds)
stable_mel = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
stable_tokens = _get_feat_extract_output_lengths(
torch.tensor(stable_mel),
@@ -389,8 +416,7 @@ class Qwen3SimulKVOnlineProcessor:
tail_embeds = asr.model.thinker.get_audio_features(
tail_features, feature_attention_mask=tail_mask,
)
if tail_embeds.dim() == 3:
tail_embeds = tail_embeds[0]
tail_embeds = self._normalize_audio_embeds(tail_embeds)
audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0)
else:
audio_embeds = cached_prefix
@@ -398,11 +424,10 @@ class Qwen3SimulKVOnlineProcessor:
audio_embeds = asr.model.thinker.get_audio_features(
input_features, feature_attention_mask=feature_attention_mask,
)
if audio_embeds.dim() == 3:
audio_embeds = audio_embeds[0]
audio_embeds = self._normalize_audio_embeds(audio_embeds)
# Update cache
cache.embeddings = audio_embeds if audio_embeds.dim() == 2 else audio_embeds[0]
cache.embeddings = audio_embeds.unsqueeze(0)
cache.encoded_samples = len(state.audio_buffer)
cache.encoded_mel_frames = total_mel_frames
stable_mel_final = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
@@ -418,9 +443,6 @@ class Qwen3SimulKVOnlineProcessor:
state = self.state
thinker = asr.model.thinker
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
_get_feat_extract_output_lengths,
)
n_audio_tokens = audio_embeds.shape[0]
@@ -482,8 +504,6 @@ class Qwen3SimulKVOnlineProcessor:
if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
return [], self.end
self.state.last_infer_samples = len(self.state.audio_buffer)
try:
timestamped_words = self._infer(is_last)
except Exception as e:
@@ -491,6 +511,11 @@ class Qwen3SimulKVOnlineProcessor:
self.state.reset_kv()
return [], self.end
# Advance the decode budget marker only after inference. Updating this
# before _infer() makes new_audio_secs collapse to zero inside the
# decoder loop and artificially caps generation to the 1-second path.
self.state.last_infer_samples = len(self.state.audio_buffer)
if not timestamped_words:
return [], self.end
@@ -529,7 +554,6 @@ class Qwen3SimulKVOnlineProcessor:
use_cache=True,
)
kv_cache = out.past_key_values
prompt_len = input_ids.shape[1]
# Step 4: Greedy decode with alignment head stopping
border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
@@ -653,7 +677,6 @@ class Qwen3SimulKVOnlineProcessor:
return []
# Strip metadata prefix (<asr_text> token)
all_generated = torch.tensor(generated_ids, device=asr.device)
num_gen = len(generated_ids)
asr_text_id = asr.asr_text_token_id
metadata_offset = 0