mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-04-01 10:24:46 +00:00
Compare commits
9 Commits
benchmarks
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fc341683a | ||
|
|
3f233dc36c | ||
|
|
db526ded34 | ||
|
|
3e5d8c5820 | ||
|
|
b102e12943 | ||
|
|
7aa3b764bd | ||
|
|
a422e604ae | ||
|
|
e14b913807 | ||
|
|
3b7a2fcc87 |
460
benchmark_mlx_simul.py
Normal file
460
benchmark_mlx_simul.py
Normal file
@@ -0,0 +1,460 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Benchmark Qwen3-ASR MLX SimulStreaming on LibriSpeech test-clean.
|
||||
|
||||
Measures:
|
||||
- Word Error Rate (WER) via jiwer
|
||||
- Real-Time Factor (RTF) = total_inference_time / total_audio_duration
|
||||
- Per-utterance stats
|
||||
|
||||
Usage:
|
||||
# Per-utterance simul-streaming (default)
|
||||
python benchmark_mlx_simul.py --model-size 0.6b
|
||||
|
||||
# Single-shot (batch-like, no streaming chunking)
|
||||
python benchmark_mlx_simul.py --model-size 0.6b --single-shot
|
||||
|
||||
# Quick test with 100 utterances
|
||||
python benchmark_mlx_simul.py --model-size 0.6b --max-utterances 100
|
||||
|
||||
# Chapter-grouped (matching H100 benchmark methodology)
|
||||
python benchmark_mlx_simul.py --model-size 0.6b --chapter-grouped
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from jiwer import cer as compute_cer
|
||||
from jiwer import wer as compute_wer
|
||||
|
||||
# Add WhisperLiveKit to path
|
||||
WLKIT_DIR = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(WLKIT_DIR))
|
||||
|
||||
from whisperlivekit.qwen3_mlx_simul import ( # noqa: E402
|
||||
Qwen3MLXSimulStreamingASR,
|
||||
Qwen3MLXSimulStreamingOnlineProcessor,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.WARNING,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("benchmark")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
SAMPLE_RATE = 16_000
|
||||
|
||||
# Alignment heads paths
|
||||
ALIGNMENT_HEADS = {
|
||||
"0.6b": str(WLKIT_DIR / "scripts" / "alignment_heads_qwen3_asr_0.6B.json"),
|
||||
"1.7b": str(WLKIT_DIR / "scripts" / "alignment_heads_qwen3_asr_1.7B_v2.json"),
|
||||
}
|
||||
|
||||
|
||||
def load_librispeech_utterances(data_dir: str, max_utterances: int = 0):
    """Yield ``(utt_id, audio_np, reference_text, duration_s)`` per utterance.

    Walks *data_dir* for LibriSpeech ``*.trans.txt`` transcript files and
    loads the FLAC sitting next to each transcript line. Audio is returned
    as float32 at 16 kHz (resampled via librosa when needed).

    A ``max_utterances`` of 0 means "no limit".
    """
    yielded = 0
    for trans_file in sorted(Path(data_dir).rglob("*.trans.txt")):
        chapter_dir = trans_file.parent
        with open(trans_file) as handle:
            for raw in handle:
                raw = raw.strip()
                if not raw:
                    continue
                # Transcript line format: "<utt_id> <reference text>".
                utt_id, _, ref_text = raw.partition(" ")

                flac_path = chapter_dir / f"{utt_id}.flac"
                if not flac_path.exists():
                    logger.warning("Missing FLAC: %s", flac_path)
                    continue

                audio, sr = sf.read(str(flac_path), dtype="float32")
                if sr != SAMPLE_RATE:
                    import librosa
                    audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)

                yield utt_id, audio, ref_text, len(audio) / SAMPLE_RATE

                yielded += 1
                if max_utterances > 0 and yielded >= max_utterances:
                    return
|
||||
|
||||
|
||||
def load_librispeech_chapters(data_dir: str):
    """Load LibriSpeech grouped by speaker-chapter.

    All utterances belonging to one speaker/chapter directory are
    concatenated — with 0.5 s of silence between consecutive clips — into a
    single long waveform, and their reference texts are joined with spaces.

    Returns a list of ``(chapter_id, audio_np, reference_text, duration_s)``
    where ``chapter_id`` is ``"<speaker>-<chapter>"``.
    """
    chapters = []
    for trans_file in sorted(Path(data_dir).rglob("*.trans.txt")):
        chapter_dir = trans_file.parent
        full_id = f"{chapter_dir.parent.name}-{chapter_dir.name}"

        audios, refs = [], []
        with open(trans_file) as handle:
            for raw in handle:
                raw = raw.strip()
                if not raw:
                    continue
                # Transcript line format: "<utt_id> <reference text>".
                utt_id, _, ref_text = raw.partition(" ")

                flac_path = chapter_dir / f"{utt_id}.flac"
                if not flac_path.exists():
                    continue

                audio, sr = sf.read(str(flac_path), dtype="float32")
                if sr != SAMPLE_RATE:
                    import librosa
                    audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)

                audios.append(audio)
                refs.append(ref_text)

        if not audios:
            continue

        # Concatenate with 0.5s silence between utterances
        gap = np.zeros(int(0.5 * SAMPLE_RATE), dtype=np.float32)
        pieces = []
        for idx, clip in enumerate(audios):
            if idx > 0:
                pieces.append(gap)
            pieces.append(clip)
        combined_audio = np.concatenate(pieces)

        chapters.append((
            full_id,
            combined_audio,
            " ".join(refs),
            len(combined_audio) / SAMPLE_RATE,
        ))

    return chapters
|
||||
|
||||
|
||||
def transcribe_simul(asr, audio, chunk_seconds=2.0):
    """Transcribe *audio* via SimulStreaming, fed in fixed-size chunks.

    The waveform is pushed to an online processor ``chunk_seconds`` at a
    time; tokens emitted after each push are accumulated, and the processor
    is flushed at the end.

    Returns ``(transcription_text, inference_time_seconds)``.
    """
    processor = Qwen3MLXSimulStreamingOnlineProcessor(asr)
    step = int(chunk_seconds * SAMPLE_RATE)
    n_samples = len(audio)
    emitted = []

    started = time.perf_counter()

    cursor = 0
    while cursor < n_samples:
        end = min(cursor + step, n_samples)
        # Stream clock for this push is the end of the chunk, in seconds.
        processor.insert_audio_chunk(audio[cursor:end], end / SAMPLE_RATE)

        tokens, _ = processor.process_iter(is_last=(end >= n_samples))
        if tokens:
            emitted.extend(tokens)
        cursor = end

    # Final flush: drain anything still buffered in the decoder.
    tail, _ = processor.finish()
    if tail:
        emitted.extend(tail)

    inference_time = time.perf_counter() - started

    transcript = "".join(tok.text for tok in emitted).strip()
    return transcript, inference_time
|
||||
|
||||
|
||||
def transcribe_single_shot(asr, audio):
    """Transcribe by pushing the entire waveform in one go (batch-like).

    Uses the same online processor as the streaming path, but inserts all
    audio at once with ``is_last=True`` and then flushes.

    Returns ``(transcription_text, inference_time_seconds)``.
    """
    processor = Qwen3MLXSimulStreamingOnlineProcessor(asr)

    started = time.perf_counter()

    processor.insert_audio_chunk(audio, len(audio) / SAMPLE_RATE)
    tokens, _ = processor.process_iter(is_last=True)

    # Flush any remainder held by the decoder.
    tail, _ = processor.finish()
    if tail:
        tokens.extend(tail)

    inference_time = time.perf_counter() - started

    transcript = "".join(tok.text for tok in tokens).strip()
    return transcript, inference_time
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
    """Prepare *text* for WER/CER scoring.

    Uppercases the input, drops every character that is neither a word
    character nor whitespace, and collapses runs of whitespace into a
    single space (trimming the ends).
    """
    cleaned = re.sub(r"[^\w\s]", "", text.upper())
    return re.sub(r"\s+", " ", cleaned).strip()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse args, load the model, run the benchmark, report.

    Flow: argument parsing -> alignment-heads lookup -> model load ->
    warmup -> dataset load (per-utterance or chapter-grouped) -> per-sample
    transcription loop -> aggregate WER/CER/RTF report -> optional JSON dump.
    """
    parser = argparse.ArgumentParser(description="Benchmark Qwen3-ASR MLX SimulStreaming")
    parser.add_argument("--model-size", default="0.6b", choices=["0.6b", "1.7b"],
                        help="Model size (default: 0.6b)")
    parser.add_argument("--max-utterances", type=int, default=0,
                        help="Max utterances to process (0=all). Ignored in chapter mode.")
    parser.add_argument("--librispeech-dir", default="/tmp/LibriSpeech/test-clean",
                        help="Path to LibriSpeech test-clean directory")
    parser.add_argument("--single-shot", action="store_true",
                        help="Feed entire audio at once instead of streaming chunks")
    parser.add_argument("--chunk-seconds", type=float, default=2.0,
                        help="Chunk size in seconds for simul-streaming (default: 2.0)")
    parser.add_argument("--border-fraction", type=float, default=0.25,
                        help="Border fraction for AlignAtt stopping (default: 0.25, matching H100 config)")
    parser.add_argument("--chapter-grouped", action="store_true",
                        help="Group utterances by speaker-chapter (matching H100 methodology)")
    parser.add_argument("--output-json", default=None,
                        help="Save per-utterance results to JSON file")
    args = parser.parse_args()

    # Check alignment heads: fall back to the model's default heuristic when
    # no per-size heads file exists on disk (heads_path is reset to None).
    heads_path = ALIGNMENT_HEADS.get(args.model_size)
    if heads_path and os.path.exists(heads_path):
        logger.info("Using alignment heads: %s", heads_path)
        with open(heads_path) as f:
            heads_data = json.load(f)
            n_heads = len(heads_data.get("alignment_heads_compact", []))
        logger.info(" Loaded %d alignment heads for border detection", n_heads)
    else:
        heads_path = None
        logger.warning("No alignment heads file found for %s! Using default heuristic.",
                       args.model_size)

    # Load model (timed, so load cost can be reported separately from RTF).
    logger.info("Loading Qwen3-ASR-%s MLX SimulStreaming model...", args.model_size.upper())
    t_load_start = time.perf_counter()
    asr = Qwen3MLXSimulStreamingASR(
        model_size=args.model_size,
        lan="en",
        alignment_heads_path=heads_path,
        border_fraction=args.border_fraction,
    )
    t_load_end = time.perf_counter()
    logger.info("Model loaded in %.2fs", t_load_end - t_load_start)

    # Verify alignment heads.
    # NOTE(review): assumes the ASR object exposes .alignment_heads and
    # .heads_by_layer — confirm against Qwen3MLXSimulStreamingASR.
    logger.info("Alignment heads active: %d heads across %d layers",
                len(asr.alignment_heads), len(asr.heads_by_layer))
    if asr.alignment_heads:
        layers = sorted(asr.heads_by_layer.keys())
        logger.info(" Active layers: %s", layers[:10])
        logger.info(" First 5 heads: %s", asr.alignment_heads[:5])

    logger.info("Config: border_fraction=%.2f, chunk_seconds=%.1f",
                args.border_fraction, args.chunk_seconds)

    # Warmup: 3 s of low-amplitude noise so later timings exclude one-time
    # compilation/initialization cost.
    logger.info("Running warmup inference...")
    dummy_audio = np.random.randn(SAMPLE_RATE * 3).astype(np.float32) * 0.01
    if args.single_shot:
        _, warmup_time = transcribe_single_shot(asr, dummy_audio)
    else:
        _, warmup_time = transcribe_simul(asr, dummy_audio, args.chunk_seconds)
    logger.info("Warmup done in %.2fs", warmup_time)

    # Determine mode (label used in logs, the report, and the JSON output).
    mode = "single-shot" if args.single_shot else "simul-streaming"
    if args.chapter_grouped:
        mode += " (chapter-grouped)"

    logger.info("Starting benchmark: model=%s, mode=%s, bf=%.2f, chunk=%.1fs",
                args.model_size, mode, args.border_fraction, args.chunk_seconds)
    logger.info("LibriSpeech dir: %s", args.librispeech_dir)

    # Load data (--max-utterances only applies to per-utterance mode).
    if args.chapter_grouped:
        samples = load_librispeech_chapters(args.librispeech_dir)
        logger.info("Loaded %d speaker-chapters", len(samples))
    else:
        samples = list(load_librispeech_utterances(
            args.librispeech_dir, args.max_utterances
        ))
        logger.info("Loaded %d utterances", len(samples))

    # Run benchmark
    references = []            # normalized reference texts, in sample order
    hypotheses = []            # normalized hypothesis texts, in sample order
    per_sample_results = []    # per-sample dicts (also dumped to JSON)
    total_audio_duration = 0.0
    total_inference_time = 0.0

    for i, (sample_id, audio, ref_text, duration) in enumerate(samples):
        if args.single_shot:
            hyp_text, infer_time = transcribe_single_shot(asr, audio)
        else:
            hyp_text, infer_time = transcribe_simul(asr, audio, args.chunk_seconds)

        ref_norm = normalize_text(ref_text)
        hyp_norm = normalize_text(hyp_text)

        # Per-sample WER (0.0 for an empty reference, where WER is undefined).
        if ref_norm:
            sample_wer = compute_wer(ref_norm, hyp_norm)
        else:
            sample_wer = 0.0

        total_audio_duration += duration
        total_inference_time += infer_time

        references.append(ref_norm)
        hypotheses.append(hyp_norm)

        result = {
            "id": sample_id,
            "ref": ref_text,
            "hyp": hyp_text,
            "ref_norm": ref_norm,
            "hyp_norm": hyp_norm,
            "duration_s": round(duration, 3),
            "infer_time_s": round(infer_time, 3),
            "rtf": round(infer_time / duration, 4) if duration > 0 else 0,
            "wer": round(sample_wer, 4),
        }
        per_sample_results.append(result)

        # Progress logging: first 5 samples, then every 50th.
        if (i + 1) % 50 == 0 or (i + 1) <= 5:
            running_wer = compute_wer(references, hypotheses)
            running_rtf = total_inference_time / total_audio_duration if total_audio_duration > 0 else 0
            logger.info(
                "[%d/%d] id=%s dur=%.1fs infer=%.2fs rtf=%.3f wer=%.1f%% "
                "| running: wer=%.2f%% rtf=%.3f",
                i + 1, len(samples), sample_id, duration, infer_time,
                infer_time / duration if duration > 0 else 0,
                sample_wer * 100, running_wer * 100, running_rtf,
            )

        # Show first few transcriptions (truncated) for a quick sanity check.
        if i < 3:
            logger.info(" REF: %s", ref_text[:120])
            logger.info(" HYP: %s", hyp_text[:120])

    # Final results
    n_samples = len(references)
    if n_samples == 0:
        logger.error("No samples processed!")
        return

    total_wer = compute_wer(references, hypotheses)
    total_cer = compute_cer(references, hypotheses)
    total_rtf = total_inference_time / total_audio_duration if total_audio_duration > 0 else 0

    total_ref_words = sum(len(r.split()) for r in references)
    total_hyp_words = sum(len(h.split()) for h in hypotheses)

    # Distribution of per-sample WER (median / p90 / p95 by sorted index).
    wers = [r["wer"] for r in per_sample_results]
    wers_sorted = sorted(wers)
    median_wer = wers_sorted[len(wers_sorted) // 2]
    p90_wer = wers_sorted[int(len(wers_sorted) * 0.9)]
    p95_wer = wers_sorted[int(len(wers_sorted) * 0.95)]
    zero_wer_count = sum(1 for w in wers if w == 0.0)

    unit = "chapters" if args.chapter_grouped else "utterances"

    print("\n" + "=" * 70)
    print(f"BENCHMARK RESULTS: Qwen3-ASR-{args.model_size.upper()} MLX SimulStreaming")
    print(f"Mode: {mode}")
    print(f"Config: border_fraction={args.border_fraction}, chunk={args.chunk_seconds}s")
    print("=" * 70)
    print(f"Samples ({unit}): {n_samples}")
    print(f"Total audio: {total_audio_duration:.1f}s ({total_audio_duration/60:.1f}min)")
    print(f"Total inference: {total_inference_time:.1f}s ({total_inference_time/60:.1f}min)")
    print(f"Reference words: {total_ref_words}")
    print(f"Hypothesis words: {total_hyp_words}")
    print("-" * 70)
    print(f"WER: {total_wer * 100:.2f}%")
    print(f"CER: {total_cer * 100:.2f}%")
    print(f"RTF: {total_rtf:.4f}")
    if total_rtf > 0:
        print(f" (1/RTF = {1/total_rtf:.1f}x realtime)")
    print("-" * 70)
    print(f"Median {unit[:3]} WER: {median_wer * 100:.2f}%")
    print(f"P90 {unit[:3]} WER: {p90_wer * 100:.2f}%")
    print(f"P95 {unit[:3]} WER: {p95_wer * 100:.2f}%")
    print(f"Zero-WER {unit[:3]}: {zero_wer_count}/{n_samples} ({zero_wer_count/n_samples*100:.1f}%)")
    print("-" * 70)
    print(f"Alignment heads: {len(asr.alignment_heads)} heads, {len(asr.heads_by_layer)} layers")
    print(f"Heads file: {heads_path or 'NONE (default heuristic)'}")
    print(f"Model loaded in: {t_load_end - t_load_start:.2f}s")
    print("=" * 70)

    # H100 reference comparison (hard-coded numbers from a prior run).
    print("\nH100 PyTorch SimulStream+KV reference (chapter-grouped, bf=0.25):")
    print(" 0.6B: WER 6.44%, RTF 0.109 (91 chapters, 602s)")
    print(" 1.7B: WER 8.09%, RTF 0.117 (91 chapters, 602s)")

    # Worst samples, with REF/HYP shown for the really bad ones (>50% WER).
    worst = sorted(per_sample_results, key=lambda r: r["wer"], reverse=True)[:10]
    print(f"\nTop 10 worst {unit}:")
    for r in worst:
        print(f" {r['id']}: WER={r['wer']*100:.1f}% dur={r['duration_s']:.1f}s rtf={r['rtf']:.3f}")
        if r['wer'] > 0.5:
            print(f" REF: {r['ref_norm'][:80]}")
            print(f" HYP: {r['hyp_norm'][:80]}")

    # Save JSON results (full per-sample detail plus config and aggregates).
    if args.output_json:
        output = {
            "model": f"Qwen3-ASR-{args.model_size.upper()}",
            "backend": "mlx-simul-streaming",
            "mode": mode,
            "platform": "Apple M5 (32GB)",
            "config": {
                "border_fraction": args.border_fraction,
                "chunk_seconds": args.chunk_seconds,
                "chapter_grouped": args.chapter_grouped,
            },
            "n_samples": n_samples,
            "total_audio_s": round(total_audio_duration, 2),
            "total_inference_s": round(total_inference_time, 2),
            "wer": round(total_wer, 6),
            "cer": round(total_cer, 6),
            "rtf": round(total_rtf, 6),
            "median_wer": round(median_wer, 6),
            "p90_wer": round(p90_wer, 6),
            "p95_wer": round(p95_wer, 6),
            "alignment_heads_count": len(asr.alignment_heads),
            "alignment_heads_file": heads_path,
            "per_sample": per_sample_results,
        }
        with open(args.output_json, "w") as f:
            json.dump(output, f, indent=2)
        logger.info("Results saved to %s", args.output_json)


if __name__ == "__main__":
    main()
|
||||
@@ -1,124 +0,0 @@
|
||||
#!/usr/bin/env python3
"""Standalone Voxtral benchmark — no whisperlivekit imports.

Batch (non-streaming) transcription benchmark of Voxtral-Mini-4B on
LibriSpeech test-clean/test-other metadata lists and five ACL6060 talks,
reporting WER (via jiwer) and real-time factor, then writing a summary to
/home/cloud/bench_voxtral_results.json.

NOTE(review): relies on a hard-coded /home/cloud/... data layout and a
CUDA device; not portable as-is.
"""
import json, logging, re, time, wave, queue, threading
# NOTE(review): queue and threading are imported but never used below.
import numpy as np
import torch

logging.basicConfig(level=logging.WARNING)
# Silence chatty third-party loggers.
for n in ["transformers","torch","httpx"]:
    logging.getLogger(n).setLevel(logging.ERROR)

from jiwer import wer as compute_wer
from transformers import AutoProcessor, VoxtralRealtimeForConditionalGeneration, TextIteratorStreamer
# NOTE(review): TextIteratorStreamer is imported but never used below.


def norm(t):
    # Lowercase, strip everything but [a-z0-9 ], collapse space runs.
    return re.sub(r' +', ' ', re.sub(r'[^a-z0-9 ]', ' ', t.lower())).strip()


def load_audio(path):
    # Read a WAV file as float32 in [-1, 1).
    # NOTE(review): assumes 16-bit mono PCM — confirm against the dataset.
    with wave.open(path, 'r') as wf:
        return np.frombuffer(wf.readframes(wf.getnframes()), dtype=np.int16).astype(np.float32) / 32768.0


# Load model
print("Loading Voxtral-Mini-4B...", flush=True)
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda:0",
)
print(f"Loaded, GPU: {torch.cuda.memory_allocated()/1e9:.1f} GB", flush=True)


def transcribe_batch(audio_np):
    """Simple batch transcription (not streaming).

    Returns (text, generation_seconds); the timing excludes feature
    extraction and decode-to-text.
    """
    # Voxtral expects audio as input_features from processor
    inputs = processor(
        audio=audio_np, sampling_rate=16000, return_tensors="pt",
    ).to("cuda:0").to(torch.bfloat16)

    t0 = time.perf_counter()
    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=1024)
    t1 = time.perf_counter()

    text = processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
    return text, t1 - t0


# 1. LibriSpeech test-clean
print("\n=== Voxtral / LibriSpeech test-clean ===", flush=True)
clean = json.load(open("/home/cloud/benchmark_data/metadata.json"))
# ta/tp accumulate total audio seconds / total processing seconds.
wers = []; ta = tp = 0
for i, s in enumerate(clean):
    audio = load_audio(s['path'])
    hyp, pt = transcribe_batch(audio)
    w = compute_wer(norm(s['reference']), norm(hyp))
    wers.append(w); ta += s['duration']; tp += pt
    if i < 3 or i % 20 == 0:
        print(f" [{i}] {s['duration']:.1f}s RTF={pt/s['duration']:.2f} WER={w:.1%} | {hyp[:60]}", flush=True)
clean_wer = np.mean(wers); clean_rtf = tp/ta
print(f" CLEAN: WER {clean_wer:.2%}, RTF {clean_rtf:.3f} ({len(clean)} samples, {ta:.0f}s)")

# 2. LibriSpeech test-other
print("\n=== Voxtral / LibriSpeech test-other ===", flush=True)
other = json.load(open("/home/cloud/benchmark_data/metadata_other.json"))
wers2 = []; ta2 = tp2 = 0
for i, s in enumerate(other):
    audio = load_audio(s['path'])
    hyp, pt = transcribe_batch(audio)
    w = compute_wer(norm(s['reference']), norm(hyp))
    wers2.append(w); ta2 += s['duration']; tp2 += pt
    if i < 3 or i % 20 == 0:
        print(f" [{i}] {s['duration']:.1f}s RTF={pt/s['duration']:.2f} WER={w:.1%}", flush=True)
other_wer = np.mean(wers2); other_rtf = tp2/ta2
print(f" OTHER: WER {other_wer:.2%}, RTF {other_rtf:.3f} ({len(other)} samples, {ta2:.0f}s)")

# 3. ACL6060 (long-form conference talks, gold references in JSONL segments)
print("\n=== Voxtral / ACL6060 ===", flush=True)
acl_results = []
for talk in ["110", "117", "268", "367", "590"]:
    audio = load_audio(f"/home/cloud/acl6060_audio/2022.acl-long.{talk}.wav")
    dur = len(audio) / 16000
    gw = []
    with open(f"/home/cloud/iwslt26-sst/inputs/en/acl6060.ts/gold-jsonl/2022.acl-long.{talk}.jsonl") as f:
        for line in f:
            gw.append(json.loads(line)["text"].strip())
    gold = " ".join(gw)

    # For long audio, process in 30s chunks
    all_hyp = []
    t0 = time.perf_counter()
    chunk_size = 30 * 16000
    for start in range(0, len(audio), chunk_size):
        chunk = audio[start:start + chunk_size]
        if len(chunk) < 1600:  # skip very short tail
            continue
        hyp, _ = transcribe_batch(chunk)
        all_hyp.append(hyp)
    t1 = time.perf_counter()

    full_hyp = " ".join(all_hyp)
    w = compute_wer(norm(gold), norm(full_hyp))
    rtf = (t1 - t0) / dur
    acl_results.append({"talk": talk, "wer": w, "rtf": rtf, "dur": dur})
    print(f" Talk {talk}: {dur:.0f}s, WER {w:.2%}, RTF {rtf:.3f}", flush=True)

acl_wer = np.mean([r["wer"] for r in acl_results])
acl_rtf = np.mean([r["rtf"] for r in acl_results])
print(f" ACL6060 AVERAGE: WER {acl_wer:.2%}, RTF {acl_rtf:.3f}")

# Summary
print(f"\n{'='*60}")
print(f" VOXTRAL BENCHMARK SUMMARY (H100 80GB)")
print(f"{'='*60}")
print(f" {'Dataset':>25} {'WER':>7} {'RTF':>7}")
print(f" {'-'*42}")
print(f" {'LibriSpeech clean':>25} {clean_wer:>6.2%} {clean_rtf:>7.3f}")
print(f" {'LibriSpeech other':>25} {other_wer:>6.2%} {other_rtf:>7.3f}")
print(f" {'ACL6060 (5 talks)':>25} {acl_wer:>6.2%} {acl_rtf:>7.3f}")

# Persist machine-readable results for downstream figure generation.
results = {
    "clean": {"avg_wer": round(float(clean_wer), 4), "rtf": round(float(clean_rtf), 3)},
    "other": {"avg_wer": round(float(other_wer), 4), "rtf": round(float(other_rtf), 3)},
    "acl6060": {"avg_wer": round(float(acl_wer), 4), "avg_rtf": round(float(acl_rtf), 3),
                "talks": [{k: (round(float(v), 4) if isinstance(v, (float, np.floating)) else v) for k, v in r.items()} for r in acl_results]},
}
json.dump(results, open("/home/cloud/bench_voxtral_results.json", "w"), indent=2)
print(f"\nSaved to /home/cloud/bench_voxtral_results.json")
||||
@@ -1,122 +0,0 @@
|
||||
#!/usr/bin/env python3
"""Benchmark Voxtral via vLLM WebSocket /v1/realtime — proper streaming.

Streams PCM16 audio over the OpenAI-style realtime WebSocket API served by
vLLM at localhost:8000, collects transcription deltas, and reports WER,
RTF, and first-word latency for LibriSpeech clean/other and five ACL6060
talks. Results go to /home/cloud/bench_voxtral_realtime_results.json.

NOTE(review): requires a running vLLM realtime server plus the hard-coded
/home/cloud/... data layout.
"""
import asyncio, json, base64, time, wave, re, os
# NOTE(review): wave and os are imported but never used below.
import numpy as np
import websockets
import librosa
from jiwer import wer as compute_wer

MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"
WS_URI = "ws://localhost:8000/v1/realtime"


def norm(t):
    # Lowercase, strip everything but [a-z0-9 ], collapse space runs.
    return re.sub(r' +', ' ', re.sub(r'[^a-z0-9 ]', ' ', t.lower())).strip()


async def transcribe(audio_path, max_tokens=4096):
    """Stream one file through the realtime endpoint.

    Returns (transcript, duration_s, rtf, first_word_latency_s).
    NOTE(review): max_tokens is accepted but never forwarded to the server.
    """
    audio, _ = librosa.load(audio_path, sr=16000, mono=True)
    pcm16 = (audio * 32767).astype(np.int16).tobytes()
    dur = len(audio) / 16000

    t0 = time.time()
    transcript = ""
    first_token_time = None

    async with websockets.connect(WS_URI, max_size=2**24) as ws:
        await ws.recv()  # session.created
        await ws.send(json.dumps({"type": "session.update", "model": MODEL}))
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))  # signal ready

        # Send audio in 4KB chunks
        for i in range(0, len(pcm16), 4096):
            await ws.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(pcm16[i:i+4096]).decode(),
            }))

        await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))

        # Accumulate deltas until "done", "error", or a 120 s receive timeout.
        while True:
            try:
                msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=120))
                if msg["type"] == "transcription.delta":
                    d = msg.get("delta", "")
                    # First non-empty delta marks first-word latency.
                    if d.strip() and first_token_time is None:
                        first_token_time = time.time() - t0
                    transcript += d
                elif msg["type"] == "transcription.done":
                    # Prefer the server's final text when provided.
                    transcript = msg.get("text", transcript)
                    break
                elif msg["type"] == "error":
                    break
            except asyncio.TimeoutError:
                break

    elapsed = time.time() - t0
    return transcript.strip(), dur, elapsed / dur, first_token_time or elapsed


async def main():
    # Warmup
    print("Warmup...", flush=True)
    await transcribe("/home/cloud/benchmark_data/librispeech_clean_0000.wav")

    # LibriSpeech clean (full 91 samples)
    print("\n=== Voxtral vLLM Realtime / LibriSpeech clean ===", flush=True)
    clean = json.load(open("/home/cloud/benchmark_data/metadata.json"))
    # ta/tp accumulate total audio seconds / total processing seconds.
    wers = []; ta = tp = 0
    for i, s in enumerate(clean):
        hyp, dur, rtf, fwl = await transcribe(s['path'])
        # Empty transcript (error/timeout) is scored as WER 1.0.
        w = compute_wer(norm(s['reference']), norm(hyp)) if hyp else 1.0
        wers.append(w); ta += dur; tp += dur * rtf
        if i < 3 or i % 20 == 0:
            print(f" [{i}] {dur:.1f}s RTF={rtf:.3f} FWL={fwl:.2f}s WER={w:.1%} | {hyp[:60]}", flush=True)
    clean_wer = np.mean(wers); clean_rtf = tp / ta
    print(f" CLEAN ({len(clean)}): WER {clean_wer:.2%}, RTF {clean_rtf:.3f}\n", flush=True)

    # LibriSpeech other (full 133 samples)
    print("=== Voxtral vLLM Realtime / LibriSpeech other ===", flush=True)
    other = json.load(open("/home/cloud/benchmark_data/metadata_other.json"))
    wers2 = []; ta2 = tp2 = 0
    for i, s in enumerate(other):
        hyp, dur, rtf, fwl = await transcribe(s['path'])
        w = compute_wer(norm(s['reference']), norm(hyp)) if hyp else 1.0
        wers2.append(w); ta2 += dur; tp2 += dur * rtf
        if i < 3 or i % 20 == 0:
            print(f" [{i}] {dur:.1f}s RTF={rtf:.3f} WER={w:.1%}", flush=True)
    other_wer = np.mean(wers2); other_rtf = tp2 / ta2
    print(f" OTHER ({len(other)}): WER {other_wer:.2%}, RTF {other_rtf:.3f}\n", flush=True)

    # ACL6060 talks (gold references are per-segment JSONL, joined here)
    print("=== Voxtral vLLM Realtime / ACL6060 ===", flush=True)
    acl = []
    for talk in ["110", "117", "268", "367", "590"]:
        gw = []
        with open(f"/home/cloud/iwslt26-sst/inputs/en/acl6060.ts/gold-jsonl/2022.acl-long.{talk}.jsonl") as f:
            for line in f: gw.append(json.loads(line)["text"].strip())
        gold = " ".join(gw)

        hyp, dur, rtf, fwl = await transcribe(f"/home/cloud/acl6060_audio/2022.acl-long.{talk}.wav")
        w = compute_wer(norm(gold), norm(hyp)) if hyp else 1.0
        acl.append({"talk": talk, "wer": round(float(w),4), "rtf": round(float(rtf),3), "dur": round(dur,1)})
        print(f" Talk {talk}: {dur:.0f}s, WER {w:.2%}, RTF {rtf:.3f}, FWL {fwl:.2f}s", flush=True)

    acl_wer = np.mean([r["wer"] for r in acl])
    acl_rtf = np.mean([r["rtf"] for r in acl])
    print(f" ACL6060 AVERAGE: WER {acl_wer:.2%}, RTF {acl_rtf:.3f}\n", flush=True)

    # Summary
    print(f"{'='*55}")
    print(f" VOXTRAL vLLM REALTIME BENCHMARK (H100)")
    print(f"{'='*55}")
    print(f" LS clean ({len(clean)}): WER {clean_wer:.2%}, RTF {clean_rtf:.3f}")
    print(f" LS other ({len(other)}): WER {other_wer:.2%}, RTF {other_rtf:.3f}")
    print(f" ACL6060 (5): WER {acl_wer:.2%}, RTF {acl_rtf:.3f}")

    # Persist machine-readable results for downstream figure generation.
    results = {
        "clean": {"avg_wer": round(float(clean_wer),4), "rtf": round(float(clean_rtf),3), "n": len(clean)},
        "other": {"avg_wer": round(float(other_wer),4), "rtf": round(float(other_rtf),3), "n": len(other)},
        "acl6060": {"avg_wer": round(float(acl_wer),4), "avg_rtf": round(float(acl_rtf),3), "talks": acl},
    }
    json.dump(results, open("/home/cloud/bench_voxtral_realtime_results.json", "w"), indent=2)
    print(f"\n Saved to /home/cloud/bench_voxtral_realtime_results.json")


asyncio.run(main())
|
||||
@@ -9,9 +9,10 @@ import json
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -113,7 +114,8 @@ def fig_scatter_acl6060():
|
||||
label_off = [(10, -12), (10, 6), (10, 6), (10, 6)]
|
||||
|
||||
for (name, d, color, marker, sz), (lx, ly) in zip(pts, label_off):
|
||||
wer = d["avg_wer"]; rtf = d["avg_rtf"]
|
||||
wer = d["avg_wer"]
|
||||
rtf = d["avg_rtf"]
|
||||
ax.scatter(rtf, wer, s=sz, c=color, marker=marker,
|
||||
edgecolors="white", linewidths=1.5, zorder=5)
|
||||
ax.annotate(name, (rtf, wer), fontsize=9.5, fontweight="bold",
|
||||
@@ -157,20 +159,26 @@ def fig_bars():
|
||||
fig, axes = plt.subplots(1, 3, figsize=(16, 6))
|
||||
|
||||
# WER
|
||||
ax = axes[0]; w = 0.36
|
||||
ax = axes[0]
|
||||
w = 0.36
|
||||
ax.bar(x - w/2, wer_c, w, color=cols, alpha=0.9, edgecolor="white", label="test-clean")
|
||||
ax.bar(x + w/2, wer_o, w, color=cols_l, alpha=0.65, edgecolor="white", label="test-other")
|
||||
ax.set_ylabel("WER %"); ax.set_title("Word Error Rate", fontweight="bold")
|
||||
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.legend(fontsize=8); ax.grid(axis="y", alpha=0.15)
|
||||
ax.set_ylabel("WER %")
|
||||
ax.set_title("Word Error Rate", fontweight="bold")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.legend(fontsize=8)
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
for i, v in enumerate(wer_c):
|
||||
ax.text(i - w/2, v + 0.2, f"{v:.1f}", ha="center", fontsize=7, fontweight="bold")
|
||||
|
||||
# RTF
|
||||
ax = axes[1]
|
||||
ax.bar(x, rtf_c, 0.55, color=cols, alpha=0.9, edgecolor="white")
|
||||
ax.set_ylabel("RTF (lower = faster)"); ax.set_title("Real-Time Factor (test-clean)", fontweight="bold")
|
||||
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.set_ylabel("RTF (lower = faster)")
|
||||
ax.set_title("Real-Time Factor (test-clean)", fontweight="bold")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
for i, v in enumerate(rtf_c):
|
||||
ax.text(i, v + 0.003, f"{v:.3f}", ha="center", fontsize=8, fontweight="bold")
|
||||
@@ -178,8 +186,10 @@ def fig_bars():
|
||||
# First-word latency
|
||||
ax = axes[2]
|
||||
ax.bar(x, fwl, 0.55, color=cols, alpha=0.9, edgecolor="white")
|
||||
ax.set_ylabel("ms"); ax.set_title("First Word Latency", fontweight="bold")
|
||||
ax.set_xticks(x); ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.set_ylabel("ms")
|
||||
ax.set_title("First Word Latency", fontweight="bold")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(names, fontsize=7.5, rotation=25, ha="right")
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
for i, v in enumerate(fwl):
|
||||
ax.text(i, v + 8, f"{v}", ha="center", fontsize=8, fontweight="bold")
|
||||
@@ -222,8 +232,10 @@ def fig_robustness():
|
||||
ax.set_xlabel("WER % on test-clean")
|
||||
ax.set_ylabel("WER % on test-other")
|
||||
ax.set_title("Clean vs Noisy Robustness (H100 80 GB)", fontsize=13, fontweight="bold", pad=12)
|
||||
ax.set_xlim(-0.3, 12); ax.set_ylim(-0.3, 12)
|
||||
ax.set_aspect("equal"); ax.grid(True, alpha=0.12)
|
||||
ax.set_xlim(-0.3, 12)
|
||||
ax.set_ylim(-0.3, 12)
|
||||
ax.set_aspect("equal")
|
||||
ax.grid(True, alpha=0.12)
|
||||
_save(fig, "robustness_clean_vs_other.png")
|
||||
|
||||
|
||||
@@ -236,7 +248,8 @@ def fig_per_talk():
|
||||
talks = DATA["acl6060"]["talks"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(9, 5))
|
||||
x = np.arange(len(talks)); w = 0.35
|
||||
x = np.arange(len(talks))
|
||||
w = 0.35
|
||||
|
||||
bars_v = ax.bar(x - w/2, [v[t] for t in talks], w, color=COLORS["voxtral"],
|
||||
edgecolor="white", label="Voxtral 4B (vLLM)")
|
||||
@@ -254,8 +267,10 @@ def fig_per_talk():
|
||||
ax.set_ylabel("WER %")
|
||||
ax.set_title("Per-Talk WER — ACL6060 Conference Talks (H100 80 GB)",
|
||||
fontsize=13, fontweight="bold", pad=12)
|
||||
ax.set_xticks(x); ax.set_xticklabels([f"Talk {t}" for t in talks])
|
||||
ax.legend(fontsize=9); ax.grid(axis="y", alpha=0.15)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels([f"Talk {t}" for t in talks])
|
||||
ax.legend(fontsize=9)
|
||||
ax.grid(axis="y", alpha=0.15)
|
||||
ax.set_ylim(0, 18)
|
||||
_save(fig, "acl6060_per_talk.png")
|
||||
|
||||
|
||||
5524
benchmarks/m5/bench_0.6b_simul_500.json
Normal file
5524
benchmarks/m5/bench_0.6b_simul_500.json
Normal file
File diff suppressed because it is too large
Load Diff
5524
benchmarks/m5/bench_1.7b_simul_500.json
Normal file
5524
benchmarks/m5/bench_1.7b_simul_500.json
Normal file
File diff suppressed because it is too large
Load Diff
141
benchmarks/m5/generate_figures.py
Normal file
141
benchmarks/m5/generate_figures.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate combined M5 vs H100 benchmark figure for WhisperLiveKit.
|
||||
|
||||
Produces a WER vs RTF scatter plot comparing Apple M5 (MLX) and
|
||||
NVIDIA H100 results on LibriSpeech test-clean.
|
||||
|
||||
Note: M5 uses per-utterance evaluation (500 samples), while H100
|
||||
uses chapter-grouped evaluation (91 chapters). Per-utterance WER
|
||||
is typically lower because short utterances avoid long-range errors.
|
||||
|
||||
Run: python3 benchmarks/m5/generate_figures.py
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
H100_DATA = json.load(open(os.path.join(DIR, "..", "h100", "results.json")))
|
||||
M5_DATA = json.load(open(os.path.join(DIR, "results.json")))
|
||||
|
||||
# -- Style --
|
||||
plt.rcParams.update({
|
||||
"font.family": "sans-serif",
|
||||
"font.size": 11,
|
||||
"axes.spines.top": False,
|
||||
"axes.spines.right": False,
|
||||
})
|
||||
|
||||
COLORS = {
|
||||
"whisper": "#d63031",
|
||||
"qwen_b": "#6c5ce7",
|
||||
"qwen_s": "#00b894",
|
||||
"voxtral": "#fdcb6e",
|
||||
"m5_qwen": "#0984e3",
|
||||
}
|
||||
|
||||
|
||||
def _save(fig, name):
|
||||
path = os.path.join(DIR, name)
|
||||
fig.savefig(path, dpi=180, bbox_inches="tight", facecolor="white")
|
||||
plt.close(fig)
|
||||
print(f" saved: {name}")
|
||||
|
||||
|
||||
def fig_m5_vs_h100():
|
||||
"""WER vs RTF scatter: M5 (MLX) and H100 (CUDA) on LibriSpeech test-clean."""
|
||||
h100 = H100_DATA["librispeech_clean"]["systems"]
|
||||
m5 = M5_DATA["models"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 7))
|
||||
|
||||
# Light green band for "good WER" zone
|
||||
ax.axhspan(0, 5, color="#f0fff0", alpha=0.5, zorder=0)
|
||||
|
||||
# --- H100 points ---
|
||||
h100_pts = [
|
||||
("Whisper large-v3\n(H100, batch)", h100["whisper_large_v3_batch"], COLORS["whisper"], "h", 220),
|
||||
("Qwen3 0.6B batch\n(H100)", h100["qwen3_0.6b_batch"], COLORS["qwen_b"], "h", 170),
|
||||
("Qwen3 1.7B batch\n(H100)", h100["qwen3_1.7b_batch"], COLORS["qwen_b"], "h", 220),
|
||||
("Voxtral 4B vLLM\n(H100)", h100["voxtral_4b_vllm_realtime"], COLORS["voxtral"], "D", 240),
|
||||
("Qwen3 0.6B SimulStream+KV\n(H100)", h100["qwen3_0.6b_simulstream_kv"], COLORS["qwen_s"], "s", 200),
|
||||
("Qwen3 1.7B SimulStream+KV\n(H100)", h100["qwen3_1.7b_simulstream_kv"], COLORS["qwen_s"], "s", 260),
|
||||
]
|
||||
h100_offsets = [(-55, 10), (-55, -22), (8, -18), (8, 10), (8, 10), (8, -18)]
|
||||
|
||||
for (name, d, color, marker, sz), (lx, ly) in zip(h100_pts, h100_offsets):
|
||||
ax.scatter(d["rtf"], d["wer"], s=sz, c=color, marker=marker,
|
||||
edgecolors="white", linewidths=1.5, zorder=5)
|
||||
ax.annotate(name, (d["rtf"], d["wer"]), fontsize=7.5, fontweight="bold",
|
||||
xytext=(lx, ly), textcoords="offset points",
|
||||
arrowprops=dict(arrowstyle="-", color="#aaa", lw=0.5))
|
||||
|
||||
# --- M5 points ---
|
||||
m5_pts = [
|
||||
("Qwen3 0.6B SimulStream\n(M5, MLX)", m5["qwen3-asr-0.6b-simul"], COLORS["m5_qwen"], "^", 260),
|
||||
("Qwen3 1.7B SimulStream\n(M5, MLX)", m5["qwen3-asr-1.7b-simul"], COLORS["m5_qwen"], "^", 300),
|
||||
]
|
||||
m5_offsets = [(8, 8), (8, -18)]
|
||||
|
||||
for (name, d, color, marker, sz), (lx, ly) in zip(m5_pts, m5_offsets):
|
||||
ax.scatter(d["rtf"], d["wer"], s=sz, c=color, marker=marker,
|
||||
edgecolors="white", linewidths=1.5, zorder=6)
|
||||
ax.annotate(name, (d["rtf"], d["wer"]), fontsize=7.5, fontweight="bold",
|
||||
xytext=(lx, ly), textcoords="offset points",
|
||||
arrowprops=dict(arrowstyle="-", color="#aaa", lw=0.5))
|
||||
|
||||
# --- Connecting lines between same models on different hardware ---
|
||||
# 0.6B: H100 SimulStream+KV -> M5 SimulStream
|
||||
ax.plot([h100["qwen3_0.6b_simulstream_kv"]["rtf"], m5["qwen3-asr-0.6b-simul"]["rtf"]],
|
||||
[h100["qwen3_0.6b_simulstream_kv"]["wer"], m5["qwen3-asr-0.6b-simul"]["wer"]],
|
||||
"--", color="#0984e3", alpha=0.3, lw=1.5, zorder=3)
|
||||
# 1.7B: H100 SimulStream+KV -> M5 SimulStream
|
||||
ax.plot([h100["qwen3_1.7b_simulstream_kv"]["rtf"], m5["qwen3-asr-1.7b-simul"]["rtf"]],
|
||||
[h100["qwen3_1.7b_simulstream_kv"]["wer"], m5["qwen3-asr-1.7b-simul"]["wer"]],
|
||||
"--", color="#0984e3", alpha=0.3, lw=1.5, zorder=3)
|
||||
|
||||
# --- RTF = 1 line (real-time boundary) ---
|
||||
ax.axvline(x=1.0, color="#e17055", linestyle=":", alpha=0.5, lw=1.5, zorder=1)
|
||||
ax.text(1.02, 0.5, "real-time\nboundary", fontsize=8, color="#e17055",
|
||||
fontstyle="italic", alpha=0.7, va="bottom")
|
||||
|
||||
# --- Methodology note ---
|
||||
ax.text(0.98, 0.02,
|
||||
"H100: chapter-grouped WER (91 chapters) | M5: per-utterance WER (500 samples)\n"
|
||||
"Per-utterance WER is typically lower -- results are not directly comparable.",
|
||||
transform=ax.transAxes, fontsize=7.5, color="#666",
|
||||
ha="right", va="bottom", fontstyle="italic",
|
||||
bbox=dict(boxstyle="round,pad=0.3", fc="#fff9e6", ec="#ddd", alpha=0.9))
|
||||
|
||||
ax.set_xlabel("RTF (lower = faster)")
|
||||
ax.set_ylabel("WER % (lower = better)")
|
||||
ax.set_title("H100 vs M5 (MLX) -- Qwen3-ASR on LibriSpeech test-clean",
|
||||
fontsize=13, fontweight="bold", pad=12)
|
||||
ax.set_xlim(-0.01, 1.1)
|
||||
ax.set_ylim(-0.5, 10)
|
||||
ax.grid(True, alpha=0.12)
|
||||
|
||||
legend = [
|
||||
mpatches.Patch(color=COLORS["whisper"], label="Whisper large-v3 (H100)"),
|
||||
mpatches.Patch(color=COLORS["qwen_b"], label="Qwen3-ASR batch (H100)"),
|
||||
mpatches.Patch(color=COLORS["qwen_s"], label="Qwen3 SimulStream+KV (H100)"),
|
||||
mpatches.Patch(color=COLORS["voxtral"], label="Voxtral 4B vLLM (H100)"),
|
||||
mpatches.Patch(color=COLORS["m5_qwen"], label="Qwen3 SimulStream (M5, MLX)"),
|
||||
plt.Line2D([0], [0], marker="h", color="w", mfc="gray", ms=8, label="Batch mode"),
|
||||
plt.Line2D([0], [0], marker="s", color="w", mfc="gray", ms=8, label="Streaming (H100)"),
|
||||
plt.Line2D([0], [0], marker="^", color="w", mfc="gray", ms=8, label="Streaming (M5)"),
|
||||
]
|
||||
ax.legend(handles=legend, fontsize=8, loc="upper right", framealpha=0.85, ncol=2)
|
||||
_save(fig, "m5_vs_h100_wer_rtf.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Generating M5 vs H100 benchmark figure...")
|
||||
fig_m5_vs_h100()
|
||||
print("Done!")
|
||||
BIN
benchmarks/m5/m5_vs_h100_wer_rtf.png
Normal file
BIN
benchmarks/m5/m5_vs_h100_wer_rtf.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 162 KiB |
9
benchmarks/m5/results.json
Normal file
9
benchmarks/m5/results.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"platform": "Apple M5 (32GB RAM, MLX fp16)",
|
||||
"dataset": "LibriSpeech test-clean",
|
||||
"methodology": "per-utterance (500 samples)",
|
||||
"models": {
|
||||
"qwen3-asr-0.6b-simul": {"wer": 3.30, "rtf": 0.263},
|
||||
"qwen3-asr-1.7b-simul": {"wer": 4.07, "rtf": 0.944}
|
||||
}
|
||||
}
|
||||
@@ -418,7 +418,31 @@ class AudioProcessor:
|
||||
logger.info("Transcription processor task finished.")
|
||||
|
||||
|
||||
async def _update_diarization_state(self, diarization_segments) -> None:
|
||||
"""Push new diarization segments into the shared state."""
|
||||
if not diarization_segments:
|
||||
return
|
||||
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
|
||||
async with self.lock:
|
||||
self.state.new_diarization.extend(diarization_segments)
|
||||
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
|
||||
|
||||
async def _drain_diarization_buffer(self) -> None:
|
||||
"""Process all remaining audio in the diarization buffer.
|
||||
|
||||
Sortformer-style backends accumulate audio in an internal buffer and
|
||||
process one chunk per ``diarize()`` call, returning ``[]`` when the
|
||||
buffer is too short. This helper loops until the buffer is fully
|
||||
consumed.
|
||||
"""
|
||||
while True:
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
if not diarization_segments:
|
||||
break
|
||||
await self._update_diarization_state(diarization_segments)
|
||||
|
||||
async def diarization_processor(self) -> None:
|
||||
has_buffer = hasattr(self.diarization, 'buffer_audio')
|
||||
while True:
|
||||
try:
|
||||
item = await get_all_from_queue(self.diarization_queue)
|
||||
@@ -429,16 +453,26 @@ class AudioProcessor:
|
||||
self.diarization.insert_silence(item.duration)
|
||||
continue
|
||||
self.diarization.insert_audio_chunk(item)
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
diar_end = 0.0
|
||||
if diarization_segments:
|
||||
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
|
||||
async with self.lock:
|
||||
self.state.new_diarization = diarization_segments
|
||||
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
|
||||
if has_buffer:
|
||||
await self._drain_diarization_buffer()
|
||||
else:
|
||||
# Cumulative backends (e.g. Diart): replace, not extend
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
diar_end = 0.0
|
||||
if diarization_segments:
|
||||
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
|
||||
async with self.lock:
|
||||
self.state.new_diarization = diarization_segments
|
||||
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
# Drain any remaining audio in the buffer before exiting
|
||||
if has_buffer:
|
||||
try:
|
||||
await self._drain_diarization_buffer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception draining diarization buffer: {e}")
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self) -> None:
|
||||
|
||||
@@ -59,6 +59,7 @@ def detect_available_backends() -> List[str]:
|
||||
|
||||
try:
|
||||
import mlx.core # noqa: F401
|
||||
|
||||
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model # noqa: F401
|
||||
backends.append("voxtral-mlx")
|
||||
except ImportError:
|
||||
|
||||
@@ -233,6 +233,7 @@ def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = 16000) -> None:
|
||||
|
||||
def _decode_audio(audio_bytes: bytes) -> tuple:
|
||||
import io
|
||||
|
||||
import soundfile as sf
|
||||
audio_array, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
|
||||
return np.array(audio_array, dtype=np.float32), sr
|
||||
|
||||
@@ -103,7 +103,6 @@ def print_report(report: BenchmarkReport, out: TextIO = sys.stderr) -> None:
|
||||
|
||||
# Per-language breakdown
|
||||
wer_by_lang = report.wer_by_language()
|
||||
rtf_by_lang = report.rtf_by_language()
|
||||
if len(wer_by_lang) > 1:
|
||||
w(f"\n {BOLD}By Language{RESET}\n")
|
||||
w(f" {'─' * 40}\n")
|
||||
|
||||
@@ -46,7 +46,6 @@ class BenchmarkRunner:
|
||||
async def run(self) -> BenchmarkReport:
|
||||
"""Run the full benchmark suite and return a report."""
|
||||
from whisperlivekit.metrics import compute_wer
|
||||
from whisperlivekit.test_harness import TestHarness
|
||||
|
||||
# Get samples
|
||||
samples = get_benchmark_samples(
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
"""
|
||||
Bridge between WhisperLiveKit STT and IWSLT26 MT pipeline.
|
||||
|
||||
Converts streaming ASRToken output from SimulStreaming into the JSONL
|
||||
format expected by the AlignAtt MT agent (iwslt26-sst).
|
||||
|
||||
Output format (one JSON per line):
|
||||
{"text": "word or phrase", "emission_time": 1.234, "is_final": false, "speech_time": 1.0}
|
||||
|
||||
Where:
|
||||
- text: the emitted word/phrase
|
||||
- emission_time: wall-clock time when the word was emitted (for compute-aware eval)
|
||||
- speech_time: timestamp in the audio (for compute-unaware eval)
|
||||
- is_final: whether this is the last word of a segment/silence boundary
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List, TextIO
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
|
||||
|
||||
class CascadeBridge:
|
||||
"""Converts ASRToken stream to JSONL for the MT agent."""
|
||||
|
||||
def __init__(self, output_file: TextIO = None):
|
||||
self.output_file = output_file
|
||||
self.start_time = time.time()
|
||||
self.entries: List[dict] = []
|
||||
|
||||
def emit_tokens(self, tokens: List[ASRToken], is_final: bool = False):
|
||||
"""Emit a batch of tokens from the STT."""
|
||||
wall_clock = time.time() - self.start_time
|
||||
|
||||
for i, token in enumerate(tokens):
|
||||
entry = {
|
||||
"text": token.text.strip(),
|
||||
"emission_time": round(wall_clock, 3),
|
||||
"speech_time": round(token.start, 3),
|
||||
"is_final": is_final and (i == len(tokens) - 1),
|
||||
}
|
||||
self.entries.append(entry)
|
||||
if self.output_file:
|
||||
self.output_file.write(json.dumps(entry) + "\n")
|
||||
self.output_file.flush()
|
||||
|
||||
def get_entries(self) -> List[dict]:
|
||||
return self.entries
|
||||
|
||||
def get_text(self) -> str:
|
||||
"""Get the full transcribed text."""
|
||||
return " ".join(e["text"] for e in self.entries if e["text"])
|
||||
|
||||
def save(self, path: str):
|
||||
"""Save all entries to a JSONL file."""
|
||||
with open(path, "w") as f:
|
||||
for entry in self.entries:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
|
||||
|
||||
def run_stt_to_jsonl(
|
||||
audio_path: str,
|
||||
output_path: str,
|
||||
model_id: str = "Qwen/Qwen3-ASR-0.6B",
|
||||
alignment_heads_path: str = None,
|
||||
border_fraction: float = 0.20,
|
||||
language: str = "en",
|
||||
chunk_sec: float = 1.0,
|
||||
):
|
||||
"""Run STT on an audio file and save JSONL output for the MT agent.
|
||||
|
||||
This is the main entry point for the cascade: audio file → JSONL.
|
||||
"""
|
||||
import wave
|
||||
import numpy as np
|
||||
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR, Qwen3SimulKVOnlineProcessor
|
||||
|
||||
# Load audio
|
||||
with wave.open(audio_path, 'r') as wf:
|
||||
audio = np.frombuffer(
|
||||
wf.readframes(wf.getnframes()), dtype=np.int16
|
||||
).astype(np.float32) / 32768.0
|
||||
|
||||
# Initialize STT
|
||||
asr = Qwen3SimulKVASR(
|
||||
model_dir=model_id,
|
||||
lan=language,
|
||||
alignment_heads_path=alignment_heads_path,
|
||||
border_fraction=border_fraction,
|
||||
)
|
||||
proc = Qwen3SimulKVOnlineProcessor(asr)
|
||||
bridge = CascadeBridge()
|
||||
|
||||
# Stream audio in chunks
|
||||
chunk_samples = int(chunk_sec * 16000)
|
||||
offset = 0
|
||||
stream_time = 0.0
|
||||
|
||||
while offset < len(audio):
|
||||
chunk = audio[offset:offset + chunk_samples]
|
||||
stream_time += len(chunk) / 16000
|
||||
proc.insert_audio_chunk(chunk, stream_time)
|
||||
words, _ = proc.process_iter(is_last=False)
|
||||
if words:
|
||||
bridge.emit_tokens(words, is_final=False)
|
||||
offset += chunk_samples
|
||||
|
||||
# Final flush
|
||||
final_words, _ = proc.finish()
|
||||
if final_words:
|
||||
bridge.emit_tokens(final_words, is_final=True)
|
||||
|
||||
# Save
|
||||
bridge.save(output_path)
|
||||
return bridge
|
||||
@@ -386,7 +386,8 @@ def cmd_models():
|
||||
# --- System info ---
|
||||
print(f"\n Platform: {platform.system()} {platform.machine()}")
|
||||
print(f" Accelerator: {_gpu_info()}")
|
||||
print(f" ffmpeg: {'found' if _check_ffmpeg() else '\033[31mNOT FOUND\033[0m (required)'}")
|
||||
_ffmpeg_status = "found" if _check_ffmpeg() else "\033[31mNOT FOUND\033[0m (required)"
|
||||
print(f" ffmpeg: {_ffmpeg_status}")
|
||||
|
||||
# --- Model catalog ---
|
||||
print("\n Models:\n")
|
||||
@@ -419,7 +420,7 @@ def cmd_models():
|
||||
)
|
||||
|
||||
# --- Quick start ---
|
||||
print(f"\n Quick start:\n")
|
||||
print("\n Quick start:\n")
|
||||
if is_apple_silicon:
|
||||
print(" wlk run voxtral-mlx # Best streaming on Apple Silicon")
|
||||
print(" wlk run large-v3-turbo # Best quality/speed balance")
|
||||
@@ -806,7 +807,7 @@ async def _run_bench_new(parsed, languages, categories):
|
||||
on_progress=on_progress,
|
||||
)
|
||||
|
||||
print(f"\n Downloading benchmark samples (cached after first run)...",
|
||||
print("\n Downloading benchmark samples (cached after first run)...",
|
||||
file=sys.stderr)
|
||||
|
||||
report = await runner.run()
|
||||
|
||||
@@ -121,6 +121,15 @@ class TranscriptionEngine:
|
||||
self.tokenizer = None
|
||||
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
||||
logger.info("Using Voxtral HF Transformers streaming backend")
|
||||
elif config.backend == "qwen3-mlx-simul":
|
||||
from whisperlivekit.qwen3_mlx_simul import Qwen3MLXSimulStreamingASR
|
||||
self.tokenizer = None
|
||||
self.asr = Qwen3MLXSimulStreamingASR(
|
||||
**transcription_common_params,
|
||||
alignment_heads_path=config.custom_alignment_heads,
|
||||
border_fraction=getattr(config, 'border_fraction', 0.15),
|
||||
)
|
||||
logger.info("Using Qwen3 MLX SimulStreaming backend")
|
||||
elif config.backend == "qwen3-mlx":
|
||||
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXASR
|
||||
self.tokenizer = None
|
||||
@@ -247,6 +256,9 @@ def online_factory(args, asr, language=None):
|
||||
if backend == "qwen3-simul-kv":
|
||||
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVOnlineProcessor
|
||||
return Qwen3SimulKVOnlineProcessor(asr)
|
||||
if backend == "qwen3-mlx-simul":
|
||||
from whisperlivekit.qwen3_mlx_simul import Qwen3MLXSimulStreamingOnlineProcessor
|
||||
return Qwen3MLXSimulStreamingOnlineProcessor(asr)
|
||||
if backend == "qwen3-mlx":
|
||||
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor
|
||||
return Qwen3MLXOnlineProcessor(asr)
|
||||
|
||||
@@ -147,8 +147,8 @@ def parse_args():
|
||||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-simul", "vllm-realtime"],
|
||||
help="Select the ASR backend implementation. Use 'qwen3' for Qwen3-ASR with LocalAgreement. Use 'qwen3-mlx' for Qwen3-ASR on Apple Silicon (MLX). Use 'qwen3-simul' for Qwen3-ASR with SimulStreaming (requires alignment heads). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx", "qwen3", "qwen3-mlx", "qwen3-mlx-simul", "qwen3-simul", "vllm-realtime"],
|
||||
help="Select the ASR backend implementation. Use 'qwen3-mlx-simul' for Qwen3-ASR SimulStreaming on Apple Silicon (MLX). Use 'qwen3-mlx' for Qwen3-ASR LocalAgreement on MLX. Use 'qwen3-simul' for Qwen3-ASR SimulStreaming (PyTorch). Use 'vllm-realtime' for vLLM Realtime WebSocket.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
|
||||
759
whisperlivekit/qwen3_mlx_simul.py
Normal file
759
whisperlivekit/qwen3_mlx_simul.py
Normal file
@@ -0,0 +1,759 @@
|
||||
"""
|
||||
Qwen3-ASR SimulStreaming (AlignAtt) on MLX for Apple Silicon.
|
||||
|
||||
Uses the ``mlx_qwen3_asr`` library for model loading, audio encoding, and
|
||||
tokenization. Implements the AlignAtt border-distance policy by monkey-
|
||||
patching ``TextAttention.__call__`` on alignment layers to capture Q (with
|
||||
RoPE) during autoregressive decode steps, then computing ``Q @ K_audio^T``
|
||||
from the KV cache to find the most-attended audio frame.
|
||||
|
||||
This is the MLX equivalent of ``qwen3_simul.py`` (PyTorch) which uses
|
||||
``register_forward_hook`` for the same purpose.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_RATE = 16_000
|
||||
|
||||
# Model size aliases (same as qwen3_mlx_asr.py)
|
||||
QWEN3_MLX_MODEL_MAPPING = {
|
||||
"base": "Qwen/Qwen3-ASR-0.6B",
|
||||
"tiny": "Qwen/Qwen3-ASR-0.6B",
|
||||
"small": "Qwen/Qwen3-ASR-0.6B",
|
||||
"large": "Qwen/Qwen3-ASR-1.7B",
|
||||
"medium": "Qwen/Qwen3-ASR-1.7B",
|
||||
"large-v3": "Qwen/Qwen3-ASR-1.7B",
|
||||
"qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
|
||||
"qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
|
||||
"qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
|
||||
"qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
|
||||
"1.7b": "Qwen/Qwen3-ASR-1.7B",
|
||||
"0.6b": "Qwen/Qwen3-ASR-0.6B",
|
||||
}
|
||||
|
||||
# Whisper language codes -> Qwen3 canonical language names
|
||||
WHISPER_TO_QWEN3_LANGUAGE = {
|
||||
"zh": "Chinese", "en": "English", "yue": "Cantonese",
|
||||
"ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish",
|
||||
"pt": "Portuguese", "id": "Indonesian", "it": "Italian",
|
||||
"ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese",
|
||||
"ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay",
|
||||
"nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
|
||||
"pl": "Polish", "cs": "Czech", "fa": "Persian",
|
||||
"el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian",
|
||||
}
|
||||
|
||||
QWEN3_TO_WHISPER_LANGUAGE = {v: k for k, v in WHISPER_TO_QWEN3_LANGUAGE.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class Qwen3MLXSimulConfig:
|
||||
language: str = "auto"
|
||||
alignment_heads_path: Optional[str] = None
|
||||
border_fraction: float = 0.15
|
||||
rewind_fraction: float = 0.12
|
||||
audio_min_len: float = 3.0
|
||||
audio_max_len: float = 15.0
|
||||
max_context_tokens: int = 30
|
||||
max_alignment_heads: int = 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-session state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SessionState:
|
||||
audio_buffer: np.ndarray = field(
|
||||
default_factory=lambda: np.array([], dtype=np.float32)
|
||||
)
|
||||
cumulative_time_offset: float = 0.0
|
||||
global_time_offset: float = 0.0
|
||||
speaker: int = -1
|
||||
|
||||
last_attend_frame: int = -15
|
||||
committed_word_count: int = 0
|
||||
committed_token_ids: List[int] = field(default_factory=list)
|
||||
detected_language: Optional[str] = None
|
||||
last_infer_samples: int = 0
|
||||
# Pending partial word from previous _infer() call.
|
||||
# When a border stops mid-word (e.g., "Vill" from "Villard"),
|
||||
# the partial is held here and prepended to the next call's output.
|
||||
pending_partial: str = ""
|
||||
pending_partial_start: Optional[float] = None
|
||||
# Whether the first emitted token of this call is a continuation of the
|
||||
# previous call's last word (no leading space → subword continuation).
|
||||
first_emit_is_continuation: bool = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared model holder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Qwen3MLXSimulStreamingASR:
|
||||
"""Loads the Qwen3-ASR model via ``mlx_qwen3_asr`` once and keeps it
|
||||
alive for the lifetime of the server. Shared across sessions."""
|
||||
|
||||
sep = ""
|
||||
SAMPLING_RATE = SAMPLE_RATE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_size: str = None,
|
||||
model_dir: str = None,
|
||||
model_path: str = None,
|
||||
lan: str = "auto",
|
||||
alignment_heads_path: Optional[str] = None,
|
||||
border_fraction: float = 0.15,
|
||||
warmup_file: Optional[str] = None,
|
||||
model_cache_dir: Optional[str] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
min_chunk_size: float = 0.1,
|
||||
direct_english_translation: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
import mlx.core as mx
|
||||
import mlx_qwen3_asr
|
||||
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = None if lan == "auto" else lan
|
||||
self.warmup_file = warmup_file
|
||||
|
||||
self.cfg = Qwen3MLXSimulConfig(
|
||||
language=lan,
|
||||
alignment_heads_path=alignment_heads_path,
|
||||
border_fraction=border_fraction,
|
||||
)
|
||||
|
||||
# Resolve model path
|
||||
resolved = model_dir or model_path
|
||||
if not resolved:
|
||||
size = (model_size or "base").lower()
|
||||
if "/" in size or size.startswith("."):
|
||||
resolved = size
|
||||
else:
|
||||
resolved = QWEN3_MLX_MODEL_MAPPING.get(size, "Qwen/Qwen3-ASR-0.6B")
|
||||
|
||||
t0 = time.time()
|
||||
logger.info("Loading Qwen3-ASR MLX model '%s' for SimulStreaming ...", resolved)
|
||||
self.model, self._config = mlx_qwen3_asr.load_model(resolved, dtype=mx.float16)
|
||||
logger.info("Model loaded in %.2fs", time.time() - t0)
|
||||
|
||||
# Tokenizer
|
||||
tok_path = getattr(self.model, "_resolved_model_path", None) or resolved
|
||||
self.tokenizer = mlx_qwen3_asr.tokenizer.Tokenizer(str(tok_path))
|
||||
|
||||
# Architecture info
|
||||
text_cfg = self._config.text_config
|
||||
self.num_layers = text_cfg.num_hidden_layers
|
||||
self.num_heads = text_cfg.num_attention_heads
|
||||
self.num_kv_heads = text_cfg.num_key_value_heads
|
||||
self.head_dim = text_cfg.head_dim
|
||||
self.gqa_ratio = self.num_heads // self.num_kv_heads
|
||||
self.audio_token_id = self._config.audio_token_id
|
||||
|
||||
logger.info(
|
||||
"Qwen3-ASR arch: %d layers x %d heads (%d kv), head_dim=%d, GQA=%d",
|
||||
self.num_layers, self.num_heads, self.num_kv_heads,
|
||||
self.head_dim, self.gqa_ratio,
|
||||
)
|
||||
|
||||
# Alignment heads
|
||||
self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
|
||||
self.heads_by_layer = {}
|
||||
for layer_idx, head_idx in self.alignment_heads:
|
||||
self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)
|
||||
|
||||
self.backend_choice = "qwen3-mlx-simul"
|
||||
|
||||
# Warmup
|
||||
if warmup_file:
|
||||
from whisperlivekit.warmup import load_file
|
||||
audio = load_file(warmup_file)
|
||||
if audio is not None:
|
||||
self._warmup(audio)
|
||||
|
||||
def _load_alignment_heads(
|
||||
self, path: Optional[str],
|
||||
) -> List[Tuple[int, int]]:
|
||||
max_heads = self.cfg.max_alignment_heads
|
||||
|
||||
if path and Path(path).exists():
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
|
||||
heads = all_heads[:max_heads]
|
||||
logger.info(
|
||||
"Loaded top %d alignment heads from %s (of %d total)",
|
||||
len(heads), path, len(all_heads),
|
||||
)
|
||||
return heads
|
||||
|
||||
# Default heuristic: last quarter of layers, all heads
|
||||
default_heads = []
|
||||
start_layer = self.num_layers * 3 // 4
|
||||
for layer in range(start_layer, self.num_layers):
|
||||
for head in range(self.num_heads):
|
||||
default_heads.append((layer, head))
|
||||
logger.warning(
|
||||
"No alignment heads file. Using default heuristic: "
|
||||
"%d heads from layers %d-%d.",
|
||||
len(default_heads), start_layer, self.num_layers - 1,
|
||||
)
|
||||
return default_heads[:max_heads]
|
||||
|
||||
def _warmup(self, audio: np.ndarray):
|
||||
import mlx.core as mx
|
||||
try:
|
||||
from mlx_qwen3_asr.audio import compute_features
|
||||
audio = audio[:SAMPLE_RATE * 2]
|
||||
mel, feat_lens = compute_features(audio)
|
||||
mel = mel.astype(mx.float16)
|
||||
audio_features, _ = self.model.audio_tower(mel, feat_lens)
|
||||
n_audio = int(audio_features.shape[1])
|
||||
prompt = self.tokenizer.build_prompt_tokens(n_audio, language="English")
|
||||
input_ids = mx.array([prompt])
|
||||
positions = mx.arange(input_ids.shape[1])[None, :]
|
||||
position_ids = mx.stack([positions, positions, positions], axis=1)
|
||||
cache = self.model.create_cache()
|
||||
logits = self.model.prefill(input_ids, audio_features, position_ids, cache)
|
||||
mx.eval(logits)
|
||||
logger.info("Qwen3 MLX SimulStreaming warmup complete")
|
||||
except Exception as e:
|
||||
logger.warning("Warmup failed: %s", e)
|
||||
|
||||
def transcribe(self, audio):
|
||||
pass # all work in the online processor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attention capture via wrapper replacement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AttnCaptureWrapper:
    """Wraps a TextAttention module to capture alignment scores during decode.

    Replaces ``layer.self_attn`` with this wrapper. On decode steps (L=1),
    recomputes Q with RoPE, reads cached K from the audio region, computes
    ``Q @ K_audio^T`` for alignment heads, and stores the argmax frame in
    ``capture["step_frames"]``.

    Python dunder resolution (``__call__``) goes through the *class*, not the
    instance, so monkey-patching ``attn.__call__`` on an ``nn.Module`` does
    not work. This wrapper class defines its own ``__call__`` and delegates
    everything else to the wrapped module via ``__getattr__``.
    """

    def __init__(self, original, layer_idx, head_indices, gqa_ratio,
                 audio_start, audio_end, capture):
        # Store in __dict__ directly to avoid triggering __getattr__
        self.__dict__["_original"] = original
        self.__dict__["_layer_idx"] = layer_idx
        self.__dict__["_head_indices"] = head_indices
        self.__dict__["_gqa_ratio"] = gqa_ratio
        self.__dict__["_audio_start"] = audio_start
        self.__dict__["_audio_end"] = audio_end
        self.__dict__["_capture"] = capture

    def __call__(self, x, cos, sin, mask=None, cache=None, layer_idx=0):
        import mlx.core as mx
        from mlx_qwen3_asr.mrope import apply_rotary_pos_emb

        orig = self.__dict__["_original"]
        B, L, _ = x.shape

        # Capture only on single-token decode steps with a live cache;
        # prefill (L > 1) falls straight through to the wrapped module below.
        if L == 1 and cache is not None:
            li = self.__dict__["_layer_idx"]
            h_indices = self.__dict__["_head_indices"]
            gqa = self.__dict__["_gqa_ratio"]
            a_start = self.__dict__["_audio_start"]
            a_end = self.__dict__["_audio_end"]
            cap = self.__dict__["_capture"]

            # Recompute Q with RoPE (cheap: single token)
            q = orig.q_proj(x)
            q = q.reshape(B, L, orig.num_heads, orig.head_dim)
            q = orig.q_norm(q)
            q = q.transpose(0, 2, 1, 3)  # (B, H, 1, D)
            q_rope, _ = apply_rotary_pos_emb(q, q, cos, sin)

            # K from cache (already has RoPE baked in from cache.update)
            k_cached = cache.keys[li]
            # Guard: cache must already cover the whole audio span.
            if k_cached is not None and a_end <= k_cached.shape[2]:
                for h_idx in h_indices:
                    # Map attention head index to its shared KV head (GQA).
                    kv_h = h_idx // gqa
                    q_h = q_rope[0, h_idx, 0]  # (head_dim,)
                    k_audio = k_cached[0, kv_h, a_start:a_end]  # (n_audio, D)
                    scores = k_audio @ q_h  # (n_audio,)
                    # Argmax over audio frames = the frame this head attends to.
                    frame = int(mx.argmax(scores).item())
                    cap["step_frames"].append(frame)

        # Always run the real attention so the model's output is unchanged.
        return orig(x, cos, sin, mask=mask, cache=cache, layer_idx=layer_idx)

    def __getattr__(self, name):
        # Delegate any other attribute access to the wrapped attention module.
        return getattr(self.__dict__["_original"], name)
||||
def _install_alignment_hooks(model, heads_by_layer, gqa_ratio, audio_start, audio_end, capture):
    """Replace ``self_attn`` on alignment layers with capture wrappers.

    Returns a list of ``(layer_idx, original_attn)`` for later restoration.
    """
    saved = []
    layers = model.model.layers
    for li, head_indices in heads_by_layer.items():
        # Skip head entries that point past the actual layer stack.
        if li >= len(layers):
            continue
        target_layer = layers[li]
        current_attn = target_layer.self_attn
        target_layer.self_attn = _AttnCaptureWrapper(
            current_attn, li, head_indices, gqa_ratio,
            audio_start, audio_end, capture,
        )
        saved.append((li, current_attn))
    return saved
|
||||
|
||||
|
||||
def _remove_alignment_hooks(model, originals):
|
||||
"""Restore original self_attn modules."""
|
||||
for layer_idx, orig_attn in originals:
|
||||
model.model.layers[layer_idx].self_attn = orig_attn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-session online processor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Qwen3MLXSimulStreamingOnlineProcessor:
    """Per-session processor implementing AlignAtt on MLX.

    Same interface as other online processors:
    insert_audio_chunk / process_iter / get_buffer / start_silence /
    end_silence / finish / warmup / new_speaker.
    """

    # Audio sample rate in Hz (module-level SAMPLE_RATE constant).
    SAMPLING_RATE = SAMPLE_RATE
    # Silences at least this long (seconds) reset the session state entirely.
    MIN_DURATION_REAL_SILENCE = 5

    def __init__(self, asr: Qwen3MLXSimulStreamingASR, logfile=sys.stderr):
        """Bind the shared ASR backend and create fresh per-session state."""
        self.asr = asr
        self.logfile = logfile
        self.end = 0.0
        self.buffer: List[ASRToken] = []
        self.state = _SessionState()

    # -- properties expected by AudioProcessor --

    @property
    def speaker(self):
        return self.state.speaker

    @speaker.setter
    def speaker(self, value):
        self.state.speaker = value

    @property
    def global_time_offset(self):
        return self.state.global_time_offset

    @global_time_offset.setter
    def global_time_offset(self, value):
        self.state.global_time_offset = value

    # -- audio ingestion --

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Append new PCM samples and trim the buffer to the configured max."""
        self.end = audio_stream_end_time
        self.state.audio_buffer = np.append(self.state.audio_buffer, audio)

        # Trim if too long; shift time/budget markers by the trimmed amount so
        # timestamps and the "new audio" accounting stay consistent.
        max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
        if len(self.state.audio_buffer) > max_samples:
            trim = len(self.state.audio_buffer) - max_samples
            self.state.audio_buffer = self.state.audio_buffer[trim:]
            self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
            self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)

    # -- main processing --

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Run one streaming iteration; returns (new tokens, stream end time)."""
        audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
        if audio_duration < self.asr.cfg.audio_min_len:
            return [], self.end

        # Throttle: at least 1s of new audio
        new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
        if not is_last and new_samples < int(1.0 * self.SAMPLING_RATE):
            return [], self.end

        try:
            words = self._infer(is_last)
        except Exception as e:
            logger.exception("Qwen3 MLX SimulStreaming inference error: %s", e)
            return [], self.end

        # Update the budget marker after _infer() so the decoder can size its
        # generation budget using the real amount of fresh audio.
        self.state.last_infer_samples = len(self.state.audio_buffer)

        if not words:
            return [], self.end

        self.buffer = []
        return words, self.end

    def _infer(self, is_last: bool) -> List[ASRToken]:
        """Run one inference cycle with alignment-head-based stopping.

        Encodes the whole audio buffer, prefills the decoder with prompt +
        committed context, then greedily decodes while alignment hooks record
        which audio frame each step attends to (AlignAtt-style border policy).
        Returns timestamped ASRTokens for the tokens deemed safe to commit.
        """
        import mlx.core as mx
        from mlx_qwen3_asr.audio import compute_features
        from mlx_qwen3_asr.generate import _detect_repetition

        asr = self.asr
        state = self.state
        model = asr.model

        # 1. Encode audio
        mel, feat_lens = compute_features(state.audio_buffer)
        mel = mel.astype(mx.float16)
        audio_features, _ = model.audio_tower(mel, feat_lens)
        n_audio_tokens = int(audio_features.shape[1])
        mx.eval(audio_features)

        if n_audio_tokens == 0:
            return []

        audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE

        # 2. Build prompt tokens
        lan = asr.cfg.language
        language = None
        if lan and lan != "auto":
            language = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)

        prompt_tokens = asr.tokenizer.build_prompt_tokens(
            n_audio_tokens=n_audio_tokens,
            language=language,
        )

        # Append committed context tokens so decoding continues the transcript
        # instead of restarting from scratch.
        if state.committed_token_ids:
            ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
            prompt_tokens.extend(ctx)

        input_ids = mx.array([prompt_tokens])
        seq_len = input_ids.shape[1]

        # 3. Find audio token range (positions of the audio placeholder token)
        audio_positions = [
            i for i, t in enumerate(prompt_tokens) if t == asr.audio_token_id
        ]
        if not audio_positions:
            return []
        audio_start = audio_positions[0]
        audio_end = audio_positions[-1] + 1

        # 4. MRoPE position IDs (three identical axes for text-only positions)
        positions = mx.arange(seq_len, dtype=mx.int32)[None, :]
        position_ids = mx.stack([positions, positions, positions], axis=1)

        # 5. Prefill (cache sized for the prompt plus the max decode budget)
        cache = model.create_cache(max_seq_len=seq_len + 120)
        logits = model.prefill(input_ids, audio_features, position_ids, cache)
        mx.eval(logits)

        # 6. Install alignment hooks
        capture = {"step_frames": []}
        originals = _install_alignment_hooks(
            model, asr.heads_by_layer, asr.gqa_ratio,
            audio_start, audio_end, capture,
        )

        # 7. Decode loop with border-distance policy
        eos_ids = set(asr.tokenizer.EOS_TOKEN_IDS)
        per_step_frames: List[List[int]] = []
        last_attend_frame = state.last_attend_frame
        border_stop_step: Optional[int] = None

        border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
        rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))

        # Max tokens: ~6 tokens/sec of speech + margin
        new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
        if is_last:
            max_tokens = min(int(audio_duration * 6) + 10, 120)
        else:
            max_tokens = min(int(max(new_audio_secs, 1.0) * 6) + 5, 40)

        token = int(mx.argmax(logits.reshape(-1)).item())
        generated = [token]

        try:
            for step in range(1, max_tokens):
                if token in eos_ids:
                    break
                if _detect_repetition(generated):
                    break

                next_ids = mx.array([[token]])
                pos_val = seq_len + step - 1
                next_pos = mx.array([[[pos_val], [pos_val], [pos_val]]], dtype=mx.int32)
                logits = model.step(next_ids, next_pos, cache, validate_input_ids=False)
                mx.eval(logits)

                token = int(mx.argmax(logits.reshape(-1)).item())
                generated.append(token)

                # Collect frames from this step
                if capture["step_frames"]:
                    per_step_frames.append(capture["step_frames"])
                    capture["step_frames"] = []

                # Border-distance check (skip first 3 steps)
                if (not is_last
                        and border_stop_step is None
                        and len(per_step_frames) >= 3):
                    latest = per_step_frames[-1]
                    if latest:
                        # Median of the alignment heads' frame votes.
                        frames_sorted = sorted(latest)
                        attended = frames_sorted[len(frames_sorted) // 2]

                        # Rewind check: attention jumped far backwards, which
                        # signals the decoder has started re-reading old audio.
                        if last_attend_frame - attended > rewind_threshold:
                            border_stop_step = max(0, len(per_step_frames) - 2)
                            break

                        last_attend_frame = attended

                        # Border check: attention is too close to the end of
                        # the available audio, so further tokens are unstable.
                        if (n_audio_tokens - attended) <= border_threshold:
                            border_stop_step = len(per_step_frames) - 1
                            break

                # Periodic eval to prevent graph buildup
                if step % 8 == 0:
                    mx.eval(cache.keys[-1])
        finally:
            _remove_alignment_hooks(model, originals)
            # Flush remaining frames
            if capture["step_frames"]:
                per_step_frames.append(capture["step_frames"])

        state.last_attend_frame = last_attend_frame

        # 8. Process generated tokens
        # Remove trailing EOS
        while generated and generated[-1] in eos_ids:
            generated.pop()

        num_gen = len(generated)
        if num_gen == 0:
            return []

        raw_text = asr.tokenizer.decode(generated)
        logger.info(
            "SimulStreaming raw: %d tokens (border_stop=%s), text=%r",
            num_gen, border_stop_step, raw_text[:100],
        )

        # 9. Strip metadata prefix ("language English<asr_text>...")
        from mlx_qwen3_asr.tokenizer import parse_asr_output
        detected_lang, _ = parse_asr_output(
            raw_text,
            user_language=language,
        )

        # Find how many tokens to skip for metadata
        metadata_offset = 0
        asr_text_tokens = asr.tokenizer.encode("<asr_text>")
        asr_text_id = asr_text_tokens[0] if asr_text_tokens else None
        if asr_text_id is not None:
            for i in range(min(num_gen, 10)):
                if generated[i] == asr_text_id:
                    metadata_offset = i + 1
                    break

        if metadata_offset > 0:
            generated = generated[metadata_offset:]
            num_gen -= metadata_offset
            per_step_frames = per_step_frames[metadata_offset:]

        if num_gen <= 0:
            return []

        # Detect language (latched once per session)
        if state.detected_language is None and detected_lang and detected_lang != "unknown":
            state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
                detected_lang, detected_lang.lower(),
            )
            logger.info("Auto-detected language: %s", state.detected_language)

        # 10. Determine how many tokens to emit
        step_frames = [f for f in per_step_frames if f]
        if border_stop_step is not None:
            emit_up_to = min(border_stop_step, num_gen)
        else:
            emit_up_to = num_gen

        if emit_up_to <= 0:
            return []

        emitted_ids = generated[:emit_up_to]

        # 11. Build timestamped words
        words = self._build_timestamped_words(
            emitted_ids, step_frames, emit_up_to,
            n_audio_tokens, audio_duration,
        )

        # Update state
        state.committed_word_count += len(words)
        state.committed_token_ids.extend(emitted_ids)

        return words

    def _build_timestamped_words(
        self,
        generated_ids: List[int],
        step_frames: List[List[int]],
        emit_up_to: int,
        n_audio_tokens: int,
        audio_duration: float,
    ) -> List[ASRToken]:
        """Build timestamped ASRToken list from generated tokens and
        alignment-head captured frames."""
        state = self.state
        asr = self.asr

        # Per-token attended frame (median of head votes); None when a step
        # produced no captured frames.
        per_token_frame: List[Optional[int]] = []
        for step_idx in range(emit_up_to):
            if step_idx < len(step_frames) and step_frames[step_idx]:
                frames = sorted(step_frames[step_idx])
                per_token_frame.append(frames[len(frames) // 2])
            else:
                per_token_frame.append(None)

        # Decode full text, split into words
        full_text = asr.tokenizer.decode(generated_ids[:emit_up_to])
        text_words = full_text.split()

        # Map words to frames proportionally (word count rarely equals token
        # count, so interpolate into the captured-frame sequence).
        all_frames = [f for f in per_token_frame if f is not None]
        word_frame_pairs = []
        for wi, word in enumerate(text_words):
            if all_frames:
                frac = wi / max(len(text_words), 1)
                frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
                frame = all_frames[frame_idx]
            else:
                frame = None
            word_frame_pairs.append((word, frame))

        # Convert to ASRToken
        tokens = []
        for i, (text, frame) in enumerate(word_frame_pairs):
            text = text.strip()
            if not text:
                continue

            if frame is not None and n_audio_tokens > 0:
                # Frame index scales linearly to time within the buffer.
                timestamp = (
                    frame / n_audio_tokens * audio_duration
                    + state.cumulative_time_offset
                )
            else:
                # Fallback: spread words evenly across the buffer duration.
                timestamp = (
                    (i / max(len(word_frame_pairs), 1)) * audio_duration
                    + state.cumulative_time_offset
                )

            # Prefix a space except for the very first word of the session.
            is_very_first_word = (i == 0 and state.committed_word_count == 0)
            display_text = text if is_very_first_word else " " + text

            token = ASRToken(
                start=round(timestamp, 2),
                end=round(timestamp + 0.1, 2),
                text=display_text,
                speaker=state.speaker,
                detected_language=state.detected_language,
            ).with_offset(state.global_time_offset)
            tokens.append(token)

        return tokens

    # -- silence / speaker / lifecycle --

    def _drain_pending(self) -> Tuple[List[ASRToken], float]:
        """Flush remaining undecoded audio with a bounded number of final
        passes; shared by start_silence() and finish()."""
        all_tokens = []
        for _ in range(5):
            tokens, _ = self.process_iter(is_last=True)
            if not tokens:
                break
            all_tokens.extend(tokens)
        return all_tokens, self.end

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Emit everything pending when the VAD reports silence onset."""
        return self._drain_pending()

    def end_silence(self, silence_duration: float, offset: float):
        """Handle the end of a silence gap.

        Short gaps are bridged with synthetic zero samples so the buffer stays
        continuous; long gaps reset the session and re-anchor the time offset.
        """
        self.end += silence_duration
        long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
        if not long_silence:
            gap_len = int(self.SAMPLING_RATE * silence_duration)
            if gap_len > 0:
                gap_silence = np.zeros(gap_len, dtype=np.float32)
                self.state.audio_buffer = np.append(
                    self.state.audio_buffer, gap_silence,
                )
        else:
            self.state = _SessionState()
            self.state.global_time_offset = silence_duration + offset

    def new_speaker(self, change_speaker):
        """Flush the current speaker's audio, then start a fresh session
        anchored at the new speaker's start time."""
        self.process_iter(is_last=True)
        self.state = _SessionState()
        self.state.speaker = change_speaker.speaker
        self.state.global_time_offset = change_speaker.start

    def get_buffer(self) -> Transcript:
        return Transcript.from_tokens(tokens=self.buffer, sep='')

    def warmup(self, audio: np.ndarray, init_prompt: str = ""):
        """Run one throwaway inference on up to 1s of audio to prime the
        model, then discard all resulting state."""
        try:
            self.state.audio_buffer = audio[:SAMPLE_RATE]
            self.process_iter(is_last=True)
            self.state = _SessionState()
            logger.info("Qwen3 MLX SimulStreaming processor warmed up")
        except Exception as e:
            logger.warning("Warmup failed: %s", e)
            self.state = _SessionState()

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Emit everything pending at end of stream."""
        return self._drain_pending()
|
||||
@@ -622,9 +622,6 @@ class Qwen3SimulStreamingOnlineProcessor:
|
||||
thinker = asr.model.thinker
|
||||
|
||||
try:
|
||||
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
|
||||
_get_feat_extract_output_lengths,
|
||||
)
|
||||
|
||||
n_audio_tokens = audio_embeds.shape[0]
|
||||
|
||||
@@ -719,7 +716,6 @@ class Qwen3SimulStreamingOnlineProcessor:
|
||||
return [], self.end
|
||||
|
||||
logger.info("Running SimulStreaming inference on %.2fs of audio (%.2fs new)", audio_duration, new_samples / self.SAMPLING_RATE)
|
||||
self.state.last_infer_samples = len(self.state.audio_buffer)
|
||||
|
||||
try:
|
||||
timestamped_words = self._infer(is_last)
|
||||
@@ -727,6 +723,10 @@ class Qwen3SimulStreamingOnlineProcessor:
|
||||
logger.exception("Qwen3 SimulStreaming inference error: %s", e)
|
||||
return [], self.end
|
||||
|
||||
# Update the decode-budget marker after inference so _infer() sees the
|
||||
# true amount of newly arrived audio.
|
||||
self.state.last_infer_samples = len(self.state.audio_buffer)
|
||||
|
||||
logger.info("SimulStreaming produced %d words", len(timestamped_words))
|
||||
if not timestamped_words:
|
||||
return [], self.end
|
||||
|
||||
@@ -158,7 +158,9 @@ class Qwen3SimulKVASR:
|
||||
_patch_transformers_compat()
|
||||
|
||||
from qwen_asr.core.transformers_backend import (
|
||||
Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor,
|
||||
Qwen3ASRConfig,
|
||||
Qwen3ASRForConditionalGeneration,
|
||||
Qwen3ASRProcessor,
|
||||
)
|
||||
from transformers import AutoConfig, AutoModel, AutoProcessor
|
||||
|
||||
@@ -210,7 +212,18 @@ class Qwen3SimulKVASR:
|
||||
if path and Path(path).exists():
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
|
||||
if "alignment_heads_compact" in data:
|
||||
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
|
||||
elif "token_alignment_heads" in data:
|
||||
all_heads = [
|
||||
(int(h["layer"]), int(h["head"]))
|
||||
for h in data["token_alignment_heads"]
|
||||
]
|
||||
else:
|
||||
raise KeyError(
|
||||
"alignment_heads_compact/token_alignment_heads not found in "
|
||||
f"{path}"
|
||||
)
|
||||
heads = all_heads[:max_heads]
|
||||
logger.info("Loaded top %d alignment heads from %s", len(heads), path)
|
||||
return heads
|
||||
@@ -333,6 +346,21 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
def get_buffer(self) -> Transcript:
|
||||
return Transcript.from_tokens(tokens=self.buffer, sep='')
|
||||
|
||||
@staticmethod
|
||||
def _normalize_audio_embeds(audio_embeds: torch.Tensor) -> torch.Tensor:
|
||||
"""Keep cached audio embeddings in a consistent 2D layout."""
|
||||
if audio_embeds.dim() == 3:
|
||||
if audio_embeds.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"Unexpected batched audio embeds shape: {tuple(audio_embeds.shape)}",
|
||||
)
|
||||
audio_embeds = audio_embeds[0]
|
||||
if audio_embeds.dim() == 1:
|
||||
audio_embeds = audio_embeds.unsqueeze(0)
|
||||
if audio_embeds.dim() != 2:
|
||||
raise ValueError(f"Unexpected audio embeds shape: {tuple(audio_embeds.shape)}")
|
||||
return audio_embeds
|
||||
|
||||
def _encode_audio(self) -> Tuple[torch.Tensor, int]:
|
||||
"""Encode full audio buffer, with caching for stable windows."""
|
||||
asr = self.asr
|
||||
@@ -364,8 +392,7 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
audio_embeds = asr.model.thinker.get_audio_features(
|
||||
input_features, feature_attention_mask=feature_attention_mask,
|
||||
)
|
||||
if audio_embeds.dim() == 3:
|
||||
audio_embeds = audio_embeds[0]
|
||||
audio_embeds = self._normalize_audio_embeds(audio_embeds)
|
||||
stable_mel = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
|
||||
stable_tokens = _get_feat_extract_output_lengths(
|
||||
torch.tensor(stable_mel),
|
||||
@@ -389,8 +416,7 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
tail_embeds = asr.model.thinker.get_audio_features(
|
||||
tail_features, feature_attention_mask=tail_mask,
|
||||
)
|
||||
if tail_embeds.dim() == 3:
|
||||
tail_embeds = tail_embeds[0]
|
||||
tail_embeds = self._normalize_audio_embeds(tail_embeds)
|
||||
audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0)
|
||||
else:
|
||||
audio_embeds = cached_prefix
|
||||
@@ -398,11 +424,10 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
audio_embeds = asr.model.thinker.get_audio_features(
|
||||
input_features, feature_attention_mask=feature_attention_mask,
|
||||
)
|
||||
if audio_embeds.dim() == 3:
|
||||
audio_embeds = audio_embeds[0]
|
||||
audio_embeds = self._normalize_audio_embeds(audio_embeds)
|
||||
|
||||
# Update cache
|
||||
cache.embeddings = audio_embeds if audio_embeds.dim() == 2 else audio_embeds[0]
|
||||
cache.embeddings = audio_embeds.unsqueeze(0)
|
||||
cache.encoded_samples = len(state.audio_buffer)
|
||||
cache.encoded_mel_frames = total_mel_frames
|
||||
stable_mel_final = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
|
||||
@@ -418,9 +443,6 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
state = self.state
|
||||
thinker = asr.model.thinker
|
||||
|
||||
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
|
||||
_get_feat_extract_output_lengths,
|
||||
)
|
||||
|
||||
n_audio_tokens = audio_embeds.shape[0]
|
||||
|
||||
@@ -482,8 +504,6 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
|
||||
return [], self.end
|
||||
|
||||
self.state.last_infer_samples = len(self.state.audio_buffer)
|
||||
|
||||
try:
|
||||
timestamped_words = self._infer(is_last)
|
||||
except Exception as e:
|
||||
@@ -491,6 +511,11 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
self.state.reset_kv()
|
||||
return [], self.end
|
||||
|
||||
# Advance the decode budget marker only after inference. Updating this
|
||||
# before _infer() makes new_audio_secs collapse to zero inside the
|
||||
# decoder loop and artificially caps generation to the 1-second path.
|
||||
self.state.last_infer_samples = len(self.state.audio_buffer)
|
||||
|
||||
if not timestamped_words:
|
||||
return [], self.end
|
||||
|
||||
@@ -529,7 +554,6 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
use_cache=True,
|
||||
)
|
||||
kv_cache = out.past_key_values
|
||||
prompt_len = input_ids.shape[1]
|
||||
|
||||
# Step 4: Greedy decode with alignment head stopping
|
||||
border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
|
||||
@@ -653,7 +677,6 @@ class Qwen3SimulKVOnlineProcessor:
|
||||
return []
|
||||
|
||||
# Strip metadata prefix (<asr_text> token)
|
||||
all_generated = torch.tensor(generated_ids, device=asr.device)
|
||||
num_gen = len(generated_ids)
|
||||
asr_text_id = asr.asr_text_token_id
|
||||
metadata_offset = 0
|
||||
|
||||
Reference in New Issue
Block a user