Files
WhisperLiveKit/scripts/run_scatter_benchmark.py
Quentin Fuxa ed503be140 qwen
2026-01-02 23:52:00 +01:00

438 lines
17 KiB
Python

#!/usr/bin/env python3
"""Run benchmark across all backend x model x policy combos for scatter plot.

Tests each configuration on long audio samples in two modes:
  - Compute-unaware (speed=0): all audio is fed instantly; measures pure model accuracy.
  - Compute-aware (speed=1.0): real-time simulation; slow models lose audio.

Usage:
    python scripts/run_scatter_benchmark.py
    python scripts/run_scatter_benchmark.py --aware      # only compute-aware
    python scripts/run_scatter_benchmark.py --unaware    # only compute-unaware
    python scripts/run_scatter_benchmark.py --lang fr    # benchmark French samples
    python scripts/run_scatter_benchmark.py --plot-only results.json
"""
import argparse
import asyncio
import gc
import json
import logging
import platform
import subprocess
import sys
import time
import warnings
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.WARNING)
# Libraries that log heavily while models load/run; only let errors through so
# the benchmark progress lines stay readable.
_QUIET_LOGGERS = (
    "whisperlivekit", "transformers", "torch", "httpx", "datasets",
    "numexpr", "faster_whisper",
)
for _logger_name in _QUIET_LOGGERS:
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
# JSON list of long benchmark recordings (name/path/reference/duration/language);
# "~" is expanded when the file is loaded.
LONG_SAMPLES_PATH = "~/.cache/whisperlivekit/benchmark_data/long_samples.json"
# ── All configurations to benchmark ──
# The Whisper-family backends are crossed with both streaming policies and two
# model sizes. Colour encodes the backend, marker the policy
# ("o" = LocalAgreement, "s" = SimulStreaming) and point size the model size.
COMBOS = [
    {"backend": backend, "model_size": model_size, "policy": policy,
     "label": f"{short} {tag} {model_size}", "color": color,
     "marker": marker, "size": dot_size}
    for backend, short, color in (
        ("faster-whisper", "fw", "#4a9eff"),
        ("mlx-whisper", "mlx", "#4ecca3"),
    )
    for policy, tag, marker in (
        ("localagreement", "LA", "o"),
        ("simulstreaming", "SS", "s"),
    )
    for model_size, dot_size in (("base", 100), ("small", 220))
] + [
    # voxtral-mlx (4B, native streaming): no model size or policy to select.
    {"backend": "voxtral-mlx", "model_size": "", "policy": "",
     "label": "voxtral mlx", "color": "#f5a623", "marker": "D", "size": 250},
]
def is_backend_available(backend):
    """Return True if the packages needed by *backend* can be imported.

    This is a best-effort probe: an import failure of any kind (not only
    ImportError, e.g. a native library failing to load) counts as
    "not available". Unknown backend names, including "", return False.

    Note: probing "qwen3"/"qwen3-simul" has a side effect -- it applies
    ``whisperlivekit.qwen3_asr._patch_transformers_compat()`` before importing
    ``qwen_asr``.
    """
    try:
        if backend == "faster-whisper":
            import faster_whisper  # noqa: F401
        elif backend == "mlx-whisper":
            import mlx_whisper  # noqa: F401
        elif backend == "whisper":
            import whisper  # noqa: F401
        elif backend == "voxtral-mlx":
            import mlx.core  # noqa: F401
            from whisperlivekit.voxtral_mlx.loader import load_voxtral_model  # noqa: F401
        elif backend == "voxtral":
            from transformers import VoxtralRealtimeForConditionalGeneration  # noqa: F401
        elif backend in ("qwen3", "qwen3-simul"):
            from whisperlivekit.qwen3_asr import _patch_transformers_compat
            _patch_transformers_compat()
            from qwen_asr import Qwen3ASRModel  # noqa: F401
        else:
            # Not a backend we know how to probe.
            return False
    except Exception:
        # Deliberately broad: any failure while importing means the backend
        # cannot be benchmarked on this machine.
        return False
    return True
def get_system_info():
    """Collect a short, best-effort description of the host machine.

    Returns:
        dict with keys ``platform``, ``machine``, ``cpu`` and ``ram_gb``.
        ``cpu`` is always a non-empty string ("unknown" if nothing could be
        determined). ``ram_gb`` is total RAM rounded to whole GiB, or None if
        it could not be determined.
    """
    import os

    def _sysctl(key):
        # macOS exposes CPU/RAM details via sysctl. Elsewhere the binary or
        # the key is missing; silence stderr so that expected failure does not
        # spam the benchmark output, and bound the call with a timeout.
        return subprocess.check_output(
            ["sysctl", "-n", key], text=True,
            stderr=subprocess.DEVNULL, timeout=5).strip()

    info = {"platform": platform.platform(), "machine": platform.machine()}

    try:
        info["cpu"] = _sysctl("machdep.cpu.brand_string")
    except Exception:
        info["cpu"] = platform.processor()
    if not info["cpu"]:
        # platform.processor() may return "" (common on Linux).
        info["cpu"] = "unknown"

    ram_bytes = None
    try:
        ram_bytes = int(_sysctl("hw.memsize"))
    except Exception:
        # POSIX fallback (Linux etc.); unavailable on Windows.
        try:
            ram_bytes = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
        except (AttributeError, ValueError, OSError):
            ram_bytes = None
    info["ram_gb"] = round(ram_bytes / (1024**3)) if ram_bytes and ram_bytes > 0 else None
    return info
async def run_combo_on_samples(combo, samples, lang="en", speed=0):
    """Run one configuration over every sample and aggregate the results.

    Args:
        combo: one ``COMBOS`` entry (backend / model_size / policy plus the
            plot styling keys label, color, marker, size).
        samples: dicts with ``name``, ``path``, ``reference`` and ``duration``
            (seconds), as returned by ``get_long_samples_for_lang``.
        lang: language code passed to the engine as ``lan``.
        speed: 0 = compute-unaware (instant dump), 1.0 = compute-aware (real-time)

    Returns:
        A result dict (styling copied from *combo* plus ``rtf``, ``wer_pct``
        and ``n_samples``), or None if every sample failed. WER is weighted by
        reference word count across samples. RTF is total reported inference
        time divided by the duration of the samples that reported timings.
    """
    from whisperlivekit.core import TranscriptionEngine
    from whisperlivekit.metrics import compute_wer
    from whisperlivekit.test_harness import TestHarness, _engine_cache

    kwargs = {"lan": lang, "pcm_input": True}
    if combo["backend"]:
        kwargs["backend"] = combo["backend"]
    if combo["model_size"]:
        kwargs["model_size"] = combo["model_size"]
    if combo.get("policy"):
        kwargs["backend_policy"] = combo["policy"]

    # Start from a clean slate -- presumably so the previous combo's engine /
    # model is released before this one loads (TODO confirm in TestHarness).
    TranscriptionEngine.reset()
    _engine_cache.clear()
    gc.collect()

    total_ref_words, total_errors = 0, 0
    total_infer_time = 0.0
    timed_audio_time = 0.0  # seconds of audio from samples that reported timings
    n_ok = 0
    for sample in samples:
        try:
            duration = sample["duration"]
            async with TestHarness(**kwargs) as h:
                await h.feed(sample["path"], speed=speed)
                await h.drain(max(5.0, duration * 0.5))
                state = await h.finish(timeout=120)
                metrics = h.metrics
                hypothesis = state.committed_text or state.text
                wer_result = compute_wer(sample["reference"], hypothesis)
                ref_words = wer_result["ref_words"]
                errors = (wer_result["substitutions"] +
                          wer_result["insertions"] +
                          wer_result["deletions"])
                # Use actual inference time from metrics, not wall clock.
                durations = metrics.transcription_durations if metrics else None
                infer_time = sum(durations) if durations else None
        except Exception as e:
            print(f" [WARN: {sample['name']} failed: {e}]", end="")
            continue
        # Accumulate only after every per-sample value was computed, so a
        # failure part-way through a sample cannot leave the totals skewed.
        total_ref_words += ref_words
        total_errors += errors
        if infer_time is not None:
            # A sample without timing data must not add audio time with zero
            # compute -- that would make the backend look faster than it is.
            total_infer_time += infer_time
            timed_audio_time += duration
        n_ok += 1

    if n_ok == 0:
        return None
    weighted_wer = total_errors / max(total_ref_words, 1)
    # Real RTF = actual inference time / duration of the timed audio
    real_rtf = total_infer_time / timed_audio_time if timed_audio_time > 0 else 0
    return {
        "label": combo["label"],
        "backend": combo["backend"],
        "model_size": combo.get("model_size", ""),
        "policy": combo.get("policy", ""),
        "color": combo["color"],
        "marker": combo["marker"],
        "size": combo["size"],
        "rtf": round(real_rtf, 4),
        "wer_pct": round(weighted_wer * 100, 1),
        "n_samples": n_ok,
    }
async def run_all(combos, samples, lang="en", speed=0):
    """Benchmark every combo in order and return the successful results.

    Combos whose backend cannot be imported are reported as skipped; combos
    for which no sample succeeded are reported as failed. Both are left out
    of the returned list. Progress is printed to stdout.
    """
    if speed > 0:
        mode_label = "compute-aware"
    else:
        mode_label = "compute-unaware"
    n_combos = len(combos)
    collected = []
    for position, combo in enumerate(combos, start=1):
        counter = f"[{position}/{n_combos}]"
        if not is_backend_available(combo["backend"]):
            print(f" {counter} SKIP {combo['label']} (not installed)")
            continue
        print(f" {counter} {combo['label']} ({mode_label})...", end="", flush=True)
        outcome = await run_combo_on_samples(combo, samples, lang, speed=speed)
        if not outcome:
            print(" FAILED (no results)")
            continue
        collected.append(outcome)
        print(f" RTF={outcome['rtf']:.2f}x WER={outcome['wer_pct']:.1f}% "
              f"({outcome['n_samples']} samples)")
    return collected
def get_long_samples_for_lang(lang="en", path=None):
    """Load long benchmark samples from long_samples.json, filtered by language.

    Args:
        lang: language code compared with each entry's ``"language"`` field;
            entries without that field never match.
        path: optional samples file; defaults to ``LONG_SAMPLES_PATH``.
            ``~`` is expanded.

    Returns:
        List of dicts with keys ``name``, ``path``, ``reference`` and
        ``duration``, in file order (empty if nothing matches *lang*).

    Exits the process with status 1 if the samples file does not exist.
    """
    import os

    resolved = os.path.expanduser(LONG_SAMPLES_PATH if path is None else path)
    if not os.path.exists(resolved):
        print(f"ERROR: Long samples file not found: {resolved}")
        print("Please generate it first (see benchmark_data/README).")
        sys.exit(1)
    # References span many languages; read as UTF-8 explicitly so the
    # platform default encoding (e.g. cp1252 on Windows) cannot break them.
    with open(resolved, encoding="utf-8") as f:
        all_samples = json.load(f)
    return [
        {"name": s["name"], "path": s["path"], "reference": s["reference"],
         "duration": s["duration"]}
        for s in all_samples
        if s.get("language") == lang
    ]
# Human-readable names for language codes, used in the plot title; codes not
# listed here are shown upper-cased instead.
LANG_NAMES = dict(
    en="English", fr="French", es="Spanish", de="German",
    pt="Portuguese", it="Italian", nl="Dutch", pl="Polish",
    zh="Chinese", ja="Japanese", ko="Korean", ru="Russian",
)
def generate_scatter(results, system_info, output_path, n_samples, lang="en",
                     mode="unaware", sample_duration=0.0):
    """Render the speed-vs-accuracy scatter plot and save it to *output_path*.

    Each result dict becomes one point at (``rtf``, ``wer_pct``), drawn with
    its ``color``/``marker``/``size`` and annotated with its ``label``. All
    points are shown (no outlier exclusion).

    Args:
        results: result dicts as produced by ``run_combo_on_samples``; may be empty.
        system_info: dict from ``get_system_info``; its ``cpu`` entry is shown
            in the title ("unknown" if missing or empty).
        output_path: image file to write.
        n_samples: number of audio samples -- shown in title.
        lang: language code -- shown in title via ``LANG_NAMES``.
        mode: "unaware" or "aware" -- shown in title.
        sample_duration: total audio duration in seconds -- shown in title.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    from matplotlib.lines import Line2D

    fig, ax = plt.subplots(figsize=(12, 7), facecolor="white")
    try:
        ax.set_facecolor("#fafafa")

        # Axis limits: fit all data with ~15% headroom, but always keep the
        # real-time line (RTF = 1) and a minimum WER range in view.
        if results:
            xmax = max(r["rtf"] for r in results) * 1.15
            ymax = max(r["wer_pct"] for r in results) * 1.15 + 1
        else:
            xmax, ymax = 0.5, 10
        xmax = max(xmax, 1.15)
        ymax = max(ymax, 8)

        # Sweet spot zone: RTF < 1.0 (real-time) and WER < 12%, shrunk to the
        # visible area when the axes are small.
        sweet_x = min(1.0, xmax * 0.85)
        sweet_y = min(12, ymax * 0.45)
        rect = plt.Rectangle((0, 0), sweet_x, sweet_y, alpha=0.07, color="#4ecca3",
                             zorder=0, linewidth=0)
        ax.add_patch(rect)
        ax.text(sweet_x - 0.005, sweet_y - 0.15, "sweet spot", ha="right", va="top",
                fontsize=10, color="#2ecc71", fontstyle="italic", fontweight="bold",
                alpha=0.5)

        # Real-time limit line
        ax.axvline(x=1.0, color="#e94560", linestyle="--", linewidth=1.5, alpha=0.4,
                   zorder=1)
        ax.text(1.02, ymax * 0.97, "real-time\nlimit", fontsize=8, color="#e94560",
                va="top", alpha=0.6)

        # Manual label offsets (in points) keyed by label name — hand-tuned to
        # avoid overlaps; labels not listed fall back to (8, -4).
        offsets = {
            "fw LA base": (8, 8),
            "fw LA small": (8, 8),
            "fw SS base": (-55, -14),
            "fw SS small": (8, 8),
            "mlx LA base": (8, 10),
            "mlx LA small": (8, 8),
            "mlx SS base": (-55, 8),
            "mlx SS small": (-55, -5),
            "voxtral mlx": (10, -14),
            "qwen3 0.6B": (10, 8),
            "qwen3-mlx 0.6B": (10, -14),
            "qwen3-mlx 1.7B": (10, 8),
            "fw LA large-v3": (8, -5),
            "fw SS large-v3": (8, 5),
        }

        for r in results:
            ax.scatter(r["rtf"], r["wer_pct"], c=r["color"], marker=r["marker"],
                       s=r["size"], edgecolors="white", linewidths=1.0, zorder=5,
                       alpha=0.85)
            ox, oy = offsets.get(r["label"], (8, -4))
            ax.annotate(r["label"], (r["rtf"], r["wer_pct"]),
                        textcoords="offset points", xytext=(ox, oy),
                        fontsize=8.5, color="#333333", fontweight="medium")

        # Axes
        ax.set_xlim(left=-0.01, right=xmax)
        ax.set_ylim(bottom=0, top=ymax)
        ax.set_xlabel("RTF (lower = faster)", fontsize=13, fontweight="bold", labelpad=8)
        ax.set_ylabel("WER % (lower = more accurate)", fontsize=13, fontweight="bold",
                      labelpad=8)
        ax.grid(True, alpha=0.15, linestyle="-", color="#cccccc")
        ax.tick_params(labelsize=10)

        # Title. `or` also covers a None / empty cpu coming from a JSON file.
        cpu = (system_info.get("cpu") or "unknown").replace("Apple ", "")
        lang_name = LANG_NAMES.get(lang, lang.upper())
        mode_label = "compute-unaware" if mode == "unaware" else "compute-aware"
        dur_str = (f"{sample_duration / 60:.0f}min" if sample_duration >= 60
                   else f"{sample_duration:.0f}s")
        ax.set_title(
            f"Speed vs Accuracy ({mode_label}) — {n_samples} {lang_name} samples, "
            f"{dur_str} ({cpu})",
            fontsize=14, fontweight="bold", pad=12)

        # Legend — backends (one entry per backend, first-seen colour)
        backend_handles = []
        seen = set()
        for r in results:
            if r["backend"] not in seen:
                seen.add(r["backend"])
                backend_handles.append(mpatches.Patch(color=r["color"], label=r["backend"]))

        # Legend — shapes (only markers actually plotted)
        marker_map = {"o": "LocalAgreement", "s": "SimulStreaming", "D": "Native streaming",
                      "h": "Batch + aligner"}
        active = {r["marker"] for r in results}
        shape_handles = [
            Line2D([0], [0], marker=m, color="#888", label=lbl,
                   markerfacecolor="#888", markersize=8, linestyle="None")
            for m, lbl in marker_map.items() if m in active
        ]
        # Legend — sizes
        shape_handles += [
            Line2D([0], [0], marker="o", color="#888", label="base",
                   markerfacecolor="#888", markersize=5, linestyle="None"),
            Line2D([0], [0], marker="o", color="#888", label="small / 4B",
                   markerfacecolor="#888", markersize=9, linestyle="None"),
        ]
        leg1 = ax.legend(handles=backend_handles, loc="upper left", fontsize=9,
                         framealpha=0.95, edgecolor="#ddd", title="Backend",
                         title_fontsize=9)
        # A second ax.legend() call replaces the first, so re-attach it.
        ax.add_artist(leg1)
        ax.legend(handles=shape_handles, loc="lower right", fontsize=8,
                  framealpha=0.95, edgecolor="#ddd", ncol=2)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight", pad_inches=0.15)
    finally:
        # main() calls this once per mode; without closing, every figure stays
        # registered with pyplot and its memory is never released.
        plt.close(fig)
    print(f"Saved {output_path}")
def main():
    """Command-line entry point.

    With ``--plot-only FILE`` it re-renders a previously saved results file.
    Otherwise it runs the benchmark for each selected mode and writes one JSON
    results file plus one scatter-plot PNG per mode.
    """
    import os
    import tempfile

    parser = argparse.ArgumentParser()
    parser.add_argument("--plot-only", default=None)
    parser.add_argument("--lang", default="en", help="Language code (en, fr, es, de, ...)")
    parser.add_argument("--output", "-o", default=None,
                        help="Output path prefix (mode suffix added automatically)")
    parser.add_argument("--json-output", default=None,
                        help="JSON output path prefix (mode suffix added automatically)")
    parser.add_argument("--aware", action="store_true",
                        help="Run only compute-aware mode (speed=1.0)")
    parser.add_argument("--unaware", action="store_true",
                        help="Run only compute-unaware mode (speed=0)")
    args = parser.parse_args()
    lang = args.lang

    # --aware / --unaware each select a single mode; giving both or neither
    # runs both, unaware first.
    if args.aware == args.unaware:
        modes = ["unaware", "aware"]
    elif args.aware:
        modes = ["aware"]
    else:
        modes = ["unaware"]

    if args.plot_only:
        with open(args.plot_only, encoding="utf-8") as f:
            data = json.load(f)
        mode = data.get("mode", "unaware")
        # Use the language recorded in the results file for both the title and
        # the default file name; --lang is only a fallback for old files.
        plot_lang = data.get("lang", lang)
        # NOTE: in plot-only mode --output is used verbatim as the image path.
        output_path = args.output or f"benchmark_scatter_{plot_lang}_{mode}.png"
        generate_scatter(data["results"], data["system_info"], output_path,
                         data["n_samples"], plot_lang,
                         mode=mode,
                         sample_duration=data.get("total_audio_s", 0))
        return

    print(f"Loading long {lang} samples from {LONG_SAMPLES_PATH}...")
    samples = get_long_samples_for_lang(lang)
    if not samples:
        print(f"ERROR: No long samples for language '{lang}'")
        sys.exit(1)
    print(f"Using {len(samples)} samples: {[s['name'] for s in samples]}")
    total_dur = sum(s["duration"] for s in samples)
    print(f"Total audio: {total_dur:.0f}s ({total_dur / 60:.1f}min)\n")

    # Filter combos to backends that support this language
    from whisperlivekit.benchmark.compat import backend_supports_language
    combos = [c for c in COMBOS if backend_supports_language(c["backend"], lang)]
    system_info = get_system_info()

    # Default JSON location: /tmp where it exists (unchanged on macOS/Linux),
    # otherwise the platform temp dir so Windows runs don't fail at the end.
    tmp_dir = "/tmp" if os.path.isdir("/tmp") else tempfile.gettempdir()

    for mode in modes:
        speed = 1.0 if mode == "aware" else 0
        mode_label = "compute-aware" if mode == "aware" else "compute-unaware"
        print(f"\n{'='*60}")
        print(f" Running {mode_label} (speed={speed})")
        print(f"{'='*60}\n")
        t0 = time.perf_counter()
        results = asyncio.run(run_all(combos, samples, lang, speed=speed))
        total = time.perf_counter() - t0

        # Save JSON first so results survive even if plotting fails.
        json_path = args.json_output or os.path.join(tmp_dir, f"bench_scatter_{lang}")
        json_file = f"{json_path}_{mode}.json"
        output_data = {
            "system_info": system_info,
            "lang": lang,
            "mode": mode,
            "speed": speed,
            "n_samples": len(samples),
            "sample_names": [s["name"] for s in samples],
            "total_audio_s": round(total_dur, 1),
            "total_benchmark_time_s": round(total, 1),
            "results": results,
        }
        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)
        print(f"\nJSON: {json_file} ({total:.0f}s total)")

        # Generate scatter plot
        output_base = args.output or f"benchmark_scatter_{lang}"
        output_path = f"{output_base}_{mode}.png"
        generate_scatter(results, system_info, output_path, len(samples), lang,
                         mode=mode, sample_duration=total_dur)


if __name__ == "__main__":
    main()