# mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
# synced 2026-03-21 16:40:35 +00:00
# 438 lines, 17 KiB, Python
#!/usr/bin/env python3
|
|
"""Run benchmark across all backend x model x policy combos for scatter plot.
|
|
|
|
Tests each configuration on long audio samples in two modes:
|
|
- Compute-unaware (speed=0): all audio dumped instantly, measures pure model accuracy
|
|
- Compute-aware (speed=1.0): real-time simulation, slow models lose audio
|
|
|
|
Usage:
|
|
python scripts/run_scatter_benchmark.py
|
|
python scripts/run_scatter_benchmark.py --aware # only compute-aware
|
|
python scripts/run_scatter_benchmark.py --unaware # only compute-unaware
|
|
python scripts/run_scatter_benchmark.py --plot-only results.json
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import gc
|
|
import json
|
|
import logging
|
|
import platform
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import warnings
|
|
|
|
# Quiet down third-party libraries so benchmark progress output stays readable.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.WARNING)

_NOISY_LOGGERS = (
    "whisperlivekit",
    "transformers",
    "torch",
    "httpx",
    "datasets",
    "numexpr",
    "faster_whisper",
)
for _logger_name in _NOISY_LOGGERS:
    logging.getLogger(_logger_name).setLevel(logging.ERROR)


# Pre-generated long-form benchmark samples (see benchmark_data/README).
LONG_SAMPLES_PATH = "~/.cache/whisperlivekit/benchmark_data/long_samples.json"
|
|
|
|
# ── All configurations to benchmark ──
# Each entry pairs engine settings (backend / model_size / policy) with the
# plot styling used by generate_scatter: "color" encodes the backend,
# "marker" encodes the streaming policy, and "size" encodes model scale.
# Empty model_size/policy (voxtral-mlx) means the backend runs with its
# built-in defaults.

COMBOS = [
    # faster-whisper x LocalAgreement
    {"backend": "faster-whisper", "model_size": "base", "policy": "localagreement",
     "label": "fw LA base", "color": "#4a9eff", "marker": "o", "size": 100},
    {"backend": "faster-whisper", "model_size": "small", "policy": "localagreement",
     "label": "fw LA small", "color": "#4a9eff", "marker": "o", "size": 220},
    # faster-whisper x SimulStreaming
    {"backend": "faster-whisper", "model_size": "base", "policy": "simulstreaming",
     "label": "fw SS base", "color": "#4a9eff", "marker": "s", "size": 100},
    {"backend": "faster-whisper", "model_size": "small", "policy": "simulstreaming",
     "label": "fw SS small", "color": "#4a9eff", "marker": "s", "size": 220},
    # mlx-whisper x LocalAgreement
    {"backend": "mlx-whisper", "model_size": "base", "policy": "localagreement",
     "label": "mlx LA base", "color": "#4ecca3", "marker": "o", "size": 100},
    {"backend": "mlx-whisper", "model_size": "small", "policy": "localagreement",
     "label": "mlx LA small", "color": "#4ecca3", "marker": "o", "size": 220},
    # mlx-whisper x SimulStreaming
    {"backend": "mlx-whisper", "model_size": "base", "policy": "simulstreaming",
     "label": "mlx SS base", "color": "#4ecca3", "marker": "s", "size": 100},
    {"backend": "mlx-whisper", "model_size": "small", "policy": "simulstreaming",
     "label": "mlx SS small", "color": "#4ecca3", "marker": "s", "size": 220},
    # voxtral-mlx (4B, native streaming)
    {"backend": "voxtral-mlx", "model_size": "", "policy": "",
     "label": "voxtral mlx", "color": "#f5a623", "marker": "D", "size": 250},
]
|
|
|
|
|
|
def is_backend_available(backend):
    """Return True if the Python dependencies for *backend* can be imported.

    Unknown backend names, missing optional dependencies, or any failure
    during the probe all yield False, so callers can safely test arbitrary
    COMBOS entries without crashing the benchmark run.
    """
    try:
        if backend == "faster-whisper":
            import faster_whisper  # noqa: F401
            return True
        elif backend == "mlx-whisper":
            import mlx_whisper  # noqa: F401
            return True
        elif backend == "whisper":
            import whisper  # noqa: F401
            return True
        elif backend == "voxtral-mlx":
            import mlx.core  # noqa: F401
            from whisperlivekit.voxtral_mlx.loader import load_voxtral_model  # noqa: F401
            return True
        elif backend == "voxtral":
            from transformers import VoxtralRealtimeForConditionalGeneration  # noqa: F401
            return True
        elif backend in ("qwen3", "qwen3-simul"):
            from whisperlivekit.qwen3_asr import _patch_transformers_compat
            _patch_transformers_compat()
            from qwen_asr import Qwen3ASRModel  # noqa: F401
            return True
    except Exception:
        # Original caught "(ImportError, Exception)" — ImportError is already
        # an Exception subclass, so a single broad handler is equivalent.
        # Broad on purpose: a broken optional install should read as "absent".
        pass
    return False
|
|
|
|
|
|
def get_system_info():
    """Collect basic host details (platform, CPU model, RAM) for report metadata."""
    info = {
        "platform": platform.platform(),
        "machine": platform.machine(),
    }
    # macOS exposes the CPU brand via sysctl; elsewhere the call fails and we
    # fall back to whatever the platform module reports.
    try:
        brand = subprocess.check_output(
            ["sysctl", "-n", "machdep.cpu.brand_string"], text=True)
        info["cpu"] = brand.strip()
    except Exception:
        info["cpu"] = platform.processor()
    # Physical memory in GiB via sysctl, or None when unavailable.
    try:
        raw = subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True)
        info["ram_gb"] = round(int(raw.strip()) / (1024 ** 3))
    except Exception:
        info["ram_gb"] = None
    return info
|
|
|
|
|
|
async def run_combo_on_samples(combo, samples, lang="en", speed=0):
    """Run one config on all samples, return averaged result.

    Args:
        combo: entry from COMBOS — engine settings plus plot styling.
        samples: dicts with "name", "path", "reference", "duration".
        lang: language code forwarded to the engine.
        speed: 0 = compute-unaware (instant dump), 1.0 = compute-aware (real-time)

    Returns:
        Aggregated result dict (label, plot style, rtf, wer_pct, n_samples),
        or None if every sample failed.
    """
    from whisperlivekit.core import TranscriptionEngine
    from whisperlivekit.metrics import compute_wer
    from whisperlivekit.test_harness import TestHarness, _engine_cache

    # Only pass keys that are set, so empty backend/model_size/policy fall
    # back to the engine's own defaults (e.g. the voxtral-mlx combo).
    kwargs = {"lan": lang, "pcm_input": True}
    if combo["backend"]:
        kwargs["backend"] = combo["backend"]
    if combo["model_size"]:
        kwargs["model_size"] = combo["model_size"]
    if combo.get("policy"):
        kwargs["backend_policy"] = combo["policy"]

    # Drop cached engines between combos so each configuration starts cold
    # and memory from the previous model is released.
    TranscriptionEngine.reset()
    _engine_cache.clear()
    gc.collect()

    # Error counts are pooled across samples to compute a single
    # duration-weighted WER rather than an average of per-sample WERs.
    total_ref_words, total_errors = 0, 0
    total_infer_time, total_audio_time = 0.0, 0.0
    n_ok = 0

    for sample in samples:
        try:
            async with TestHarness(**kwargs) as h:
                # feed -> drain -> finish: drain presumably lets pending audio
                # be processed before finish() flushes final text — the
                # ordering matters (see TestHarness).
                await h.feed(sample["path"], speed=speed)
                await h.drain(max(5.0, sample["duration"] * 0.5))
                state = await h.finish(timeout=120)
                metrics = h.metrics

                hypothesis = state.committed_text or state.text
                wer_result = compute_wer(sample["reference"], hypothesis)

                total_ref_words += wer_result["ref_words"]
                total_errors += (wer_result["substitutions"] +
                                 wer_result["insertions"] +
                                 wer_result["deletions"])

                # Use actual inference time from metrics, not wall clock
                if metrics and metrics.transcription_durations:
                    total_infer_time += sum(metrics.transcription_durations)
                total_audio_time += sample["duration"]
                n_ok += 1
        except Exception as e:
            # Best-effort: one failed sample must not abort the whole combo.
            print(f" [WARN: {sample['name']} failed: {e}]", end="")

    if n_ok == 0:
        return None

    weighted_wer = total_errors / max(total_ref_words, 1)
    # Real RTF = actual inference time / audio duration
    real_rtf = total_infer_time / total_audio_time if total_audio_time > 0 else 0

    return {
        "label": combo["label"],
        "backend": combo["backend"],
        "model_size": combo.get("model_size", ""),
        "policy": combo.get("policy", ""),
        "color": combo["color"],
        "marker": combo["marker"],
        "size": combo["size"],
        "rtf": round(real_rtf, 4),
        "wer_pct": round(weighted_wer * 100, 1),
        "n_samples": n_ok,
    }
|
|
|
|
|
|
async def run_all(combos, samples, lang="en", speed=0):
    """Benchmark every available combo on *samples*; return the result dicts.

    Combos whose backend is not installed are skipped; combos that fail on
    every sample are reported but omitted from the returned list.
    """
    mode_label = "compute-aware" if speed > 0 else "compute-unaware"
    total = len(combos)
    results = []
    for idx, combo in enumerate(combos, start=1):
        if not is_backend_available(combo["backend"]):
            print(f" [{idx}/{total}] SKIP {combo['label']} (not installed)")
            continue
        print(f" [{idx}/{total}] {combo['label']} ({mode_label})...", end="", flush=True)
        result = await run_combo_on_samples(combo, samples, lang, speed=speed)
        if not result:
            print(" FAILED (no results)")
            continue
        results.append(result)
        print(f" RTF={result['rtf']:.2f}x WER={result['wer_pct']:.1f}% ({result['n_samples']} samples)")
    return results
|
|
|
|
|
|
def get_long_samples_for_lang(lang="en"):
    """Load long benchmark samples from long_samples.json, filtered by language."""
    import os
    path = os.path.expanduser(LONG_SAMPLES_PATH)
    if not os.path.exists(path):
        print(f"ERROR: Long samples file not found: {path}")
        print("Please generate it first (see benchmark_data/README).")
        sys.exit(1)
    with open(path) as f:
        all_samples = json.load(f)
    # Filter by language and keep only the fields the benchmark needs.
    return [
        {
            "name": s["name"],
            "path": s["path"],
            "reference": s["reference"],
            "duration": s["duration"],
        }
        for s in all_samples
        if s["language"] == lang
    ]
|
|
|
|
|
|
# Human-readable language names for plot titles; unknown codes fall back to
# lang.upper() at the call site (see generate_scatter).
LANG_NAMES = {
    "en": "English", "fr": "French", "es": "Spanish", "de": "German",
    "pt": "Portuguese", "it": "Italian", "nl": "Dutch", "pl": "Polish",
    "zh": "Chinese", "ja": "Japanese", "ko": "Korean", "ru": "Russian",
}
|
|
|
|
|
|
def generate_scatter(results, system_info, output_path, n_samples, lang="en",
                     mode="unaware", sample_duration=0.0):
    """Generate scatter plot.

    Args:
        results: result dicts from run_all (rtf, wer_pct, label, plot style).
        system_info: dict from get_system_info -- CPU name shown in title.
        output_path: PNG file to write.
        n_samples: number of audio samples -- shown in title.
        lang: language code, mapped through LANG_NAMES for the title.
        mode: "unaware" or "aware" -- shown in title
        sample_duration: total audio duration in seconds -- shown in title
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend: render straight to file
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    from matplotlib.lines import Line2D

    fig, ax = plt.subplots(figsize=(12, 7), facecolor="white")
    ax.set_facecolor("#fafafa")

    # Show ALL points on chart (no outlier exclusion)
    main = results
    slow = []  # always empty here; the "beyond real-time" note below is kept for reuse

    # Axis limits: fit all data
    if main:
        xmax = max(r["rtf"] for r in main) * 1.15
        ymax = max(r["wer_pct"] for r in main) * 1.15 + 1
    else:
        xmax, ymax = 0.5, 10
    xmax = max(xmax, 1.15)  # always show the real-time line
    ymax = max(ymax, 8)

    # Sweet spot zone: RTF < 1.0 (real-time) and WER < 12%
    sweet_x = min(1.0, xmax * 0.85)
    sweet_y = min(12, ymax * 0.45)
    rect = plt.Rectangle((0, 0), sweet_x, sweet_y, alpha=0.07, color="#4ecca3",
                         zorder=0, linewidth=0)
    ax.add_patch(rect)
    ax.text(sweet_x - 0.005, sweet_y - 0.15, "sweet spot", ha="right", va="top",
            fontsize=10, color="#2ecc71", fontstyle="italic", fontweight="bold", alpha=0.5)

    # Real-time limit line
    ax.axvline(x=1.0, color="#e94560", linestyle="--", linewidth=1.5, alpha=0.4, zorder=1)
    ax.text(1.02, ymax * 0.97, "real-time\nlimit", fontsize=8, color="#e94560",
            va="top", alpha=0.6)

    # Manual label offsets keyed by label name — hand-tuned
    # (offsets in points; labels missing from this map get (8, -4) below).
    OFFSETS = {
        "fw LA base": (8, 8),
        "fw LA small": (8, 8),
        "fw SS base": (-55, -14),
        "fw SS small": (8, 8),
        "mlx LA base": (8, 10),
        "mlx LA small": (8, 8),
        "mlx SS base": (-55, 8),
        "mlx SS small": (-55, -5),
        "voxtral mlx": (10, -14),
        "qwen3 0.6B": (10, 8),
        "qwen3-mlx 0.6B": (10, -14),
        "qwen3-mlx 1.7B": (10, 8),
        "fw LA large-v3": (8, -5),
        "fw SS large-v3": (8, 5),
    }

    # Plot main points
    for r in main:
        ax.scatter(r["rtf"], r["wer_pct"], c=r["color"], marker=r["marker"],
                   s=r["size"], edgecolors="white", linewidths=1.0, zorder=5, alpha=0.85)
        ox, oy = OFFSETS.get(r["label"], (8, -4))
        ax.annotate(r["label"], (r["rtf"], r["wer_pct"]),
                    textcoords="offset points", xytext=(ox, oy),
                    fontsize=8.5, color="#333333", fontweight="medium")

    # Note slow backends outside main view (dead branch while slow == [])
    if slow:
        lines = []
        for r in slow:
            lines.append(f"{r['label']}: RTF={r['rtf']:.1f}x, WER={r['wer_pct']:.1f}%")
        note = "Beyond real-time:\n" + "\n".join(lines)
        ax.text(xmax * 0.97, ymax * 0.97, note, ha="right", va="top",
                fontsize=7.5, color="#777777", fontstyle="italic",
                bbox=dict(boxstyle="round,pad=0.4", facecolor="#f8f8f8",
                          edgecolor="#dddddd", alpha=0.9))

    # Axes
    ax.set_xlim(left=-0.01, right=xmax)
    ax.set_ylim(bottom=0, top=ymax)
    ax.set_xlabel("RTF (lower = faster)", fontsize=13, fontweight="bold", labelpad=8)
    ax.set_ylabel("WER % (lower = more accurate)", fontsize=13, fontweight="bold", labelpad=8)
    ax.grid(True, alpha=0.15, linestyle="-", color="#cccccc")
    ax.tick_params(labelsize=10)

    # Title
    cpu = system_info.get("cpu", "unknown").replace("Apple ", "")
    lang_name = LANG_NAMES.get(lang, lang.upper())
    mode_label = "compute-unaware" if mode == "unaware" else "compute-aware"
    dur_str = f"{sample_duration / 60:.0f}min" if sample_duration >= 60 else f"{sample_duration:.0f}s"
    ax.set_title(
        f"Speed vs Accuracy ({mode_label}) — {n_samples} {lang_name} samples, {dur_str} ({cpu})",
        fontsize=14, fontweight="bold", pad=12)

    # Legend — backends (one color patch per distinct backend, first-seen order)
    backend_handles = []
    seen = set()
    for r in results:
        if r["backend"] not in seen:
            seen.add(r["backend"])
            backend_handles.append(mpatches.Patch(color=r["color"], label=r["backend"]))

    # Legend — shapes (only markers actually present in the results)
    marker_map = {"o": "LocalAgreement", "s": "SimulStreaming", "D": "Native streaming",
                  "h": "Batch + aligner"}
    active = set(r["marker"] for r in results)
    shape_handles = [
        Line2D([0], [0], marker=m, color="#888", label=lbl,
               markerfacecolor="#888", markersize=8, linestyle="None")
        for m, lbl in marker_map.items() if m in active
    ]
    # sizes
    shape_handles += [
        Line2D([0], [0], marker="o", color="#888", label="base",
               markerfacecolor="#888", markersize=5, linestyle="None"),
        Line2D([0], [0], marker="o", color="#888", label="small / 4B",
               markerfacecolor="#888", markersize=9, linestyle="None"),
    ]

    # Two legends: add_artist keeps the first alive when the second is created.
    leg1 = ax.legend(handles=backend_handles, loc="upper left", fontsize=9,
                     framealpha=0.95, edgecolor="#ddd", title="Backend", title_fontsize=9)
    ax.add_artist(leg1)
    ax.legend(handles=shape_handles, loc="lower right", fontsize=8,
              framealpha=0.95, edgecolor="#ddd", ncol=2)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight", pad_inches=0.15)
    print(f"Saved {output_path}")
|
|
|
|
|
|
def main():
    """Parse CLI args, run the benchmark for each requested mode, and write
    a JSON results file plus a scatter-plot PNG per mode.

    With --plot-only, skips benchmarking and re-plots from a saved JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--plot-only", default=None,
                        help="Skip benchmarking; re-plot from this results JSON")
    parser.add_argument("--lang", default="en", help="Language code (en, fr, es, de, ...)")
    parser.add_argument("--output", "-o", default=None,
                        help="Output path prefix (mode suffix added automatically)")
    parser.add_argument("--json-output", default=None,
                        help="JSON output path prefix (mode suffix added automatically)")
    parser.add_argument("--aware", action="store_true",
                        help="Run only compute-aware mode (speed=1.0)")
    parser.add_argument("--unaware", action="store_true",
                        help="Run only compute-unaware mode (speed=0)")
    args = parser.parse_args()

    lang = args.lang

    # Determine which modes to run
    if args.aware and args.unaware:
        modes = ["unaware", "aware"]
    elif args.aware:
        modes = ["aware"]
    elif args.unaware:
        modes = ["unaware"]
    else:
        # Default: run both
        modes = ["unaware", "aware"]

    if args.plot_only:
        # Fix: close the JSON file deterministically (was json.load(open(...))).
        with open(args.plot_only) as f:
            data = json.load(f)
        mode = data.get("mode", "unaware")
        # Fix: treat --output as a prefix here too (mode suffix appended),
        # matching its help text and the benchmark path below. The default
        # filename is unchanged.
        output_base = args.output or f"benchmark_scatter_{lang}"
        output_path = f"{output_base}_{mode}.png"
        generate_scatter(data["results"], data["system_info"], output_path,
                         data["n_samples"], data.get("lang", "en"),
                         mode=mode,
                         sample_duration=data.get("total_audio_s", 0))
        return

    print(f"Loading long {lang} samples from {LONG_SAMPLES_PATH}...")
    samples = get_long_samples_for_lang(lang)
    if not samples:
        print(f"ERROR: No long samples for language '{lang}'")
        sys.exit(1)
    print(f"Using {len(samples)} samples: {[s['name'] for s in samples]}")
    total_dur = sum(s["duration"] for s in samples)
    print(f"Total audio: {total_dur:.0f}s ({total_dur / 60:.1f}min)\n")

    # Filter combos to backends that support this language
    from whisperlivekit.benchmark.compat import backend_supports_language
    combos = [c for c in COMBOS if backend_supports_language(c["backend"], lang)]

    system_info = get_system_info()

    for mode in modes:
        speed = 1.0 if mode == "aware" else 0
        mode_label = "compute-aware" if mode == "aware" else "compute-unaware"
        print(f"\n{'='*60}")
        print(f" Running {mode_label} (speed={speed})")
        print(f"{'='*60}\n")

        t0 = time.time()
        results = asyncio.run(run_all(combos, samples, lang, speed=speed))
        total = time.time() - t0

        # Save JSON (mode suffix keeps aware/unaware runs side by side)
        json_path = args.json_output or f"/tmp/bench_scatter_{lang}"
        json_file = f"{json_path}_{mode}.json"
        output_data = {
            "system_info": system_info,
            "lang": lang,
            "mode": mode,
            "speed": speed,
            "n_samples": len(samples),
            "sample_names": [s["name"] for s in samples],
            "total_audio_s": round(total_dur, 1),
            "total_benchmark_time_s": round(total, 1),
            "results": results,
        }
        with open(json_file, "w") as f:
            json.dump(output_data, f, indent=2)
        print(f"\nJSON: {json_file} ({total:.0f}s total)")

        # Generate scatter plot
        output_base = args.output or f"benchmark_scatter_{lang}"
        output_path = f"{output_base}_{mode}.png"
        generate_scatter(results, system_info, output_path, len(samples), lang,
                         mode=mode, sample_duration=total_dur)
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|