diff --git a/.gitignore b/.gitignore
index a015198..ecfdcd4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -119,9 +119,11 @@ run_*.sh
*.pt
# Debug & testing
-test_*.py
+/test_*.py
+!test_backend_offline.py
launch.json
.DS_Store
-test/*
+/test/
+!tests/
nllb-200-distilled-600M-ctranslate2/*
*.mp3
\ No newline at end of file
diff --git a/BENCHMARK.md b/BENCHMARK.md
new file mode 100644
index 0000000..81239f0
--- /dev/null
+++ b/BENCHMARK.md
@@ -0,0 +1,205 @@
+# WhisperLiveKit Benchmark Report
+
+Benchmark comparing all supported ASR backends, streaming policies, and model sizes on Apple Silicon.
+All tests run through the full AudioProcessor pipeline (same code path as production WebSocket).
+
+## Test Environment
+
+| Property | Value |
+|----------|-------|
+| Hardware | Apple M4, 32 GB RAM |
+| OS | macOS 25.3.0 (arm64) |
+| Python | 3.13 |
+| faster-whisper | 1.2.1 |
+| mlx-whisper | installed (via mlx) |
+| Voxtral MLX | native MLX backend |
+| Voxtral (HF) | transformers-based |
+| VAC (Silero VAD) | enabled unless noted |
+| Chunk size | 100 ms |
+| Pacing | no-realtime (as fast as possible) |
+
+## Audio Test Files
+
+| File | Duration | Language | Speakers | Description |
+|------|----------|----------|----------|-------------|
+| `00_00_07_english_1_speaker.wav` | 7.2 s | English | 1 | Short dictation with pauses |
+| `00_00_16_french_1_speaker.wav` | 16.3 s | French | 1 | French speech with intentional silence gaps |
+| `00_00_30_english_3_speakers.wav` | 30.0 s | English | 3 | Multi-speaker conversation |
+
+Ground truth transcripts (`.transcript.json`) with per-word timestamps are hand-verified.
+
+---
+
+## Results
+
+### English -- Short (7.2 s, 1 speaker)
+
+| Backend | Policy | Model | RTF | WER | Timestamp MAE |
+|---------|--------|-------|-----|-----|---------------|
+| faster-whisper | LocalAgreement | base | 0.20x | 21.1% | 0.080 s |
+| faster-whisper | SimulStreaming | base | 0.14x | 0.0% | 0.239 s |
+| faster-whisper | LocalAgreement | small | 0.59x | 21.1% | 0.089 s |
+| faster-whisper | SimulStreaming | small | 0.39x | 0.0% | 0.221 s |
+| mlx-whisper | LocalAgreement | base | 0.05x | 21.1% | 0.080 s |
+| mlx-whisper | SimulStreaming | base | 0.14x | 10.5% | 0.245 s |
+| mlx-whisper | LocalAgreement | small | 0.16x | 21.1% | 0.089 s |
+| mlx-whisper | SimulStreaming | small | 0.20x | 10.5% | 0.226 s |
+| voxtral-mlx | voxtral | 4B | 0.32x | 0.0% | 0.254 s |
+| voxtral (HF) | voxtral | 4B | 1.29x | 0.0% | 1.876 s |
+
+### English -- Multi-speaker (30.0 s, 3 speakers)
+
+| Backend | Policy | Model | RTF | WER | Timestamp MAE |
+|---------|--------|-------|-----|-----|---------------|
+| faster-whisper | LocalAgreement | base | 0.24x | 44.7% | 0.235 s |
+| faster-whisper | SimulStreaming | base | 0.10x | 5.3% | 0.398 s |
+| faster-whisper | LocalAgreement | small | 0.59x | 25.0% | 0.226 s |
+| faster-whisper | SimulStreaming | small | 0.26x | 5.3% | 0.387 s |
+| mlx-whisper | LocalAgreement | base | 0.06x | 23.7% | 0.237 s |
+| mlx-whisper | SimulStreaming | base | 0.11x | 5.3% | 0.395 s |
+| mlx-whisper | LocalAgreement | small | 0.13x | 25.0% | 0.226 s |
+| mlx-whisper | SimulStreaming | small | 0.20x | 5.3% | 0.394 s |
+| voxtral-mlx | voxtral | 4B | 0.31x | 9.2% | 0.176 s |
+| voxtral (HF) | voxtral | 4B | 1.00x | 32.9% | 1.034 s |
+
+
+
+
+
+
+
+
+
+### French (16.3 s, 1 speaker, `--language fr`)
+
+| Backend | Policy | Model | RTF | WER | Timestamp MAE |
+|---------|--------|-------|-----|-----|---------------|
+| faster-whisper | LocalAgreement | base | 0.22x | 25.7% | 3.460 s |
+| faster-whisper | SimulStreaming | base | 0.10x | 31.4% | 3.660 s |
+| faster-whisper | LocalAgreement | small | 0.76x | 42.9% | 0.051 s |
+| faster-whisper | SimulStreaming | small | 0.29x | 25.7% | 0.219 s |
+| mlx-whisper | LocalAgreement | base | 0.09x | ~45%* | ~5.0 s* |
+| mlx-whisper | SimulStreaming | base | 0.09x | 40.0% | 3.540 s |
+| mlx-whisper | LocalAgreement | small | 0.14x | 25.7% | 0.083 s |
+| mlx-whisper | SimulStreaming | small | 0.17x | 31.4% | 0.203 s |
+| voxtral-mlx | voxtral | 4B | 0.18x | 37.1% | 3.422 s |
+| voxtral (HF) | voxtral | 4B | 0.63x | 28.6% | 4.040 s |
+
+\* mlx-whisper + LocalAgreement + base is unstable on this French file (WER fluctuates 34-1037% across runs due to hallucination loops). The `small` model does not have this problem.
+
+**Timestamp note:** The base model produces very high timestamp MAE (3.4-3.7s) on this French file because it misaligns words around the silence gaps. The small model handles this much better (0.05-0.22s MAE). Voxtral also drifts on the silence gaps.
+
+---
+
+## Model Size Comparison (base vs small)
+
+| | base | small | Observation |
+|--|------|-------|-------------|
+| **RTF** | 0.05-0.24x | 0.13-0.76x | small is 2-3x slower |
+| **English WER (SS)** | 0-5.3% | 0-5.3% | No improvement: SimulStreaming already saturates on base |
+| **English WER (LA)** | 21-44.7% | 21-25% | small reduces LA errors on longer audio |
+| **French WER** | 25-40% | 25-43% | Mixed: depends on backend/policy combo |
+| **French timestamps** | 3.4-5.0s MAE | 0.05-0.22s MAE | small is dramatically better for French timestamps |
+
+In short: **base + SimulStreaming** gives the best speed/accuracy tradeoff for English. The small model only helps if you need LocalAgreement (for subtitle-grade timestamps) or non-English languages.
+
+---
+
+## Key Findings
+
+### Speed (RTF = processing time / audio duration, lower is better)
+
+1. **mlx-whisper + LocalAgreement + base** is the fastest combo on Apple Silicon: 0.05-0.06x RTF on English. 30 seconds of audio in under 2 seconds.
+2. For **faster-whisper**, SimulStreaming is faster than LocalAgreement. For **mlx-whisper**, it is the opposite: LocalAgreement (0.05-0.06x) outperforms SimulStreaming (0.11-0.14x) on speed.
+3. **voxtral-mlx** runs at 0.18-0.32x RTF -- 3-5x slower than mlx-whisper base, but well within real-time.
+4. **voxtral (HF transformers)** hits 1.0-1.3x RTF. At the real-time boundary on Apple Silicon. Use the MLX variant instead.
+5. The **small** model is 2-3x slower than base across all backends.
+
+### Accuracy (WER = Word Error Rate, lower is better)
+
+1. **SimulStreaming** gives dramatically lower WER than LocalAgreement on the whisper backends. On the 30s English file: 5.3% vs 23-44%.
+2. **voxtral-mlx** hits 0% on short English and 9.2% on multi-speaker. It auto-detects language natively. Whisper also supports `--language auto`, but tends to bias towards English on short segments.
+3. **LocalAgreement** tends to repeat the last sentence at end-of-stream (a known LCP artifact), inflating WER. This is visible in the 21% WER on the 7s file -- the same 4 extra words appear in every LA run.
+4. On **French** with the correct `--language fr`, whisper base achieves 25-40% WER -- comparable to Voxtral's 28-37%. The small model does not consistently improve French WER.
+
+### Timestamps (MAE = Mean Absolute Error on word start times)
+
+1. **LocalAgreement** gives the best timestamps on English (0.08-0.09s MAE).
+2. **SimulStreaming** is less precise (0.22-0.40s MAE) but good enough for most applications.
+3. On French with silence gaps, **base model timestamps are unreliable** (3.4-5s MAE). The **small model fixes this** (0.05-0.22s MAE). This is the strongest argument for using `small` over `base`.
+4. **voxtral-mlx** has good timestamps on English (0.18-0.25s MAE) but drifts on audio with long silence gaps (3.4s MAE on the French file).
+
+### VAC (Voice Activity Classification) Impact
+
+| Backend | Policy | VAC | 7s English WER | 30s English WER |
+|---------|--------|-----|----------------|-----------------|
+| faster-whisper | LocalAgreement | on | 21.1% | 44.7% |
+| faster-whisper | LocalAgreement | off | 100.0% | 100.0% |
+| voxtral-mlx | voxtral | on | 0.0% | 9.2% |
+| voxtral-mlx | voxtral | off | 0.0% | 9.2% |
+
+- **Whisper backends need VAC** to work in streaming mode. Without it the buffer logic breaks down and you get empty or garbage output.
+- **Voxtral is unaffected by VAC** since it handles its own internal chunking. Identical results with or without. VAC still saves compute on silent segments.
+
+---
+
+## Recommendations
+
+| Use Case | Backend | Policy | Model | Notes |
+|----------|---------|--------|-------|-------|
+| Fastest English (Apple Silicon) | mlx-whisper | SimulStreaming | base | 0.11x RTF, 5.3% WER |
+| Fastest English (Linux/GPU) | faster-whisper | SimulStreaming | base | 0.10x RTF, 5.3% WER |
+| Best accuracy, English | faster-whisper | SimulStreaming | small | 0.26x RTF, 5.3% WER, still fast |
+| Multilingual / auto-detect | voxtral-mlx | voxtral | 4B | 100+ languages, 0.18-0.32x RTF |
+| Best timestamps | any | LocalAgreement | small | 0.05-0.09s MAE, good for subtitles |
+| Low memory / embedded | mlx-whisper | SimulStreaming | base | Smallest footprint, fastest response |
+
+---
+
+## Caveats
+
+- **3 test files, ~53 seconds total.** Results give relative rankings between backends but should not be taken as definitive WER numbers. Run on your own data for production decisions.
+- **RTF varies between runs** (up to +/-30%) depending on thermal state, background processes, and model caching. The numbers above are single sequential runs on a warm machine.
+- **Only base and small tested.** Medium and large-v3 would likely improve WER at the cost of higher RTF. We did not test them here because they are slow on Apple Silicon without GPU.
+
+---
+
+## Reproducing These Benchmarks
+
+```bash
+# Install test dependencies
+pip install -e ".[test]"
+
+# Single backend test
+python test_backend_offline.py --backend faster-whisper --policy simulstreaming --model base --no-realtime
+
+# With a specific language
+python test_backend_offline.py --backend mlx-whisper --policy simulstreaming --model small --lan fr --no-realtime
+
+# Multi-backend auto-detect benchmark
+python test_backend_offline.py --benchmark --no-realtime
+
+# Export to JSON
+python test_backend_offline.py --benchmark --no-realtime --json results.json
+
+# Test with your own audio
+python test_backend_offline.py --backend voxtral-mlx --audio your_file.wav --no-realtime
+```
+
+The benchmark harness computes WER and timestamp accuracy automatically when ground truth
+`.transcript.json` files exist alongside the audio files. See `audio_tests/` for the format.
+
+---
+
+## Help Us Benchmark on More Hardware
+
+These results are from a single Apple M4 machine. We'd love to see numbers from other setups: Linux with CUDA GPUs, older Macs, different CPU architectures, cloud instances, etc.
+
+If you run the benchmark on your hardware, please open an issue or PR with your results and we will add them here. The more data points we have, the better the recommendations get.
+
+What we are especially interested in:
+- **NVIDIA GPUs** (RTX 3090, 4090, A100, T4, etc.) with faster-whisper
+- **Older Apple Silicon** (M1, M2, M3) with mlx-whisper and voxtral-mlx
+- **Medium and large-v3 models** (we only tested base and small so far)
+- **Longer audio files** or domain-specific audio (medical, legal, call center)
+- **Other languages** beyond English and French
diff --git a/README.md b/README.md
index 948f015..f97ba06 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
+- [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) (2025) - 4B-parameter multilingual speech model by Mistral AI
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
@@ -75,15 +76,42 @@ Go to `chrome-extension` for instructions.
|-----------|-------------|
| **Windows/Linux optimizations** | `faster-whisper` |
| **Apple Silicon optimizations** | `mlx-whisper` |
+| **Voxtral (multilingual, auto-detect)** | `transformers torch` (or use built-in `voxtral-mlx` on Apple Silicon) |
| **Translation** | `nllw` |
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| OpenAI API | `openai` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
-See **Parameters & Configuration** below on how to use them.
+See **Parameters & Configuration** below on how to use them.
+
+
+
+
+
+See **[BENCHMARK.md](BENCHMARK.md)** for the full benchmark with tables, model size comparison, and more.
+We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR!
+### Voxtral Backend
+
+WhisperLiveKit supports [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602),
+a 4B-parameter speech model from Mistral AI that natively handles 100+ languages with automatic
+language detection. Whisper also supports auto-detection (`--language auto`), but Voxtral's per-chunk
+detection is more reliable and does not bias towards English.
+
+```bash
+# Apple Silicon (native MLX, recommended)
+wlk --backend voxtral-mlx
+
+# Linux/GPU (HuggingFace transformers)
+pip install transformers torch
+wlk --backend voxtral
+```
+
+Voxtral uses its own streaming policy and does not use LocalAgreement or SimulStreaming.
+See [BENCHMARK.md](BENCHMARK.md) for performance numbers.
+
### Usage Examples
**Command-line Interface**: Start the transcription server with various options:
@@ -92,8 +120,11 @@ See **Parameters & Configuration** below on how to use them.
# Large model and translate from french to danish
wlk --model large-v3 --language fr --target-language da
-# Diarization and server listening on */80
+# Diarization and server listening on */80
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
+
+# Voxtral multilingual (auto-detects language)
+wlk --backend voxtral-mlx
```
@@ -151,7 +182,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
-| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
+| `--backend` | ASR backend selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. Options: `mlx-whisper`, `faster-whisper`, `whisper`, `openai-api` (LocalAgreement only), `voxtral-mlx` (Apple Silicon), `voxtral` (HuggingFace) | `auto` |
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
@@ -271,5 +302,29 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
-## 🔮 Use Cases
+## Testing & Benchmarks
+
+WhisperLiveKit includes a unit test suite and an offline benchmark harness.
+
+```bash
+# Install test dependencies
+pip install -e ".[test]"
+
+# Run unit tests (no model download required)
+pytest tests/ -v
+
+# Benchmark a single backend
+python test_backend_offline.py --backend faster-whisper --no-realtime
+
+# Benchmark all installed backends
+python test_backend_offline.py --benchmark --no-realtime
+
+# Export benchmark results as JSON
+python test_backend_offline.py --benchmark --no-realtime --json results.json
+```
+
+See [BENCHMARK.md](BENCHMARK.md) for a full comparison of backends, policies, WER, speed, and
+timestamp accuracy on Apple Silicon.
+
+## Use Cases
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
diff --git a/audio_tests/00_00_07_english_1_speaker.transcript.json b/audio_tests/00_00_07_english_1_speaker.transcript.json
new file mode 100644
index 0000000..43ca785
--- /dev/null
+++ b/audio_tests/00_00_07_english_1_speaker.transcript.json
@@ -0,0 +1,97 @@
+[
+ {
+ "word": "This",
+ "start": 0.0,
+ "end": 0.24
+ },
+ {
+ "word": "is",
+ "start": 0.24,
+ "end": 0.56
+ },
+ {
+ "word": "a",
+ "start": 0.56,
+ "end": 0.76
+ },
+ {
+ "word": "transcription",
+ "start": 0.76,
+ "end": 1.32
+ },
+ {
+ "word": "test.",
+ "start": 1.32,
+ "end": 2.0
+ },
+ {
+ "word": "We",
+ "start": 2.4,
+ "end": 2.5
+ },
+ {
+ "word": "want",
+ "start": 2.5,
+ "end": 2.66
+ },
+ {
+ "word": "to",
+ "start": 2.66,
+ "end": 2.84
+ },
+ {
+ "word": "see",
+ "start": 2.84,
+ "end": 3.1
+ },
+ {
+ "word": "if",
+ "start": 3.1,
+ "end": 3.34
+ },
+ {
+ "word": "we",
+ "start": 3.34,
+ "end": 3.5
+ },
+ {
+ "word": "can",
+ "start": 3.5,
+ "end": 3.68
+ },
+ {
+ "word": "use",
+ "start": 3.68,
+ "end": 4.04
+ },
+ {
+ "word": "smaller",
+ "start": 4.04,
+ "end": 4.76
+ },
+ {
+ "word": "chunks.",
+ "start": 4.76,
+ "end": 5.16
+ },
+ {
+ "word": "What",
+ "start": 6.06,
+ "end": 6.32
+ },
+ {
+ "word": "do",
+ "start": 6.32,
+ "end": 6.44
+ },
+ {
+ "word": "you",
+ "start": 6.44,
+ "end": 6.58
+ },
+ {
+ "word": "think?",
+ "start": 6.58,
+ "end": 6.84
+ }
+]
\ No newline at end of file
diff --git a/audio_tests/00_00_16_french_1_speaker.transcript.json b/audio_tests/00_00_16_french_1_speaker.transcript.json
new file mode 100644
index 0000000..07c0b31
--- /dev/null
+++ b/audio_tests/00_00_16_french_1_speaker.transcript.json
@@ -0,0 +1,177 @@
+[
+ {
+ "word": "Ok,",
+ "start": 2.02,
+ "end": 2.38
+ },
+ {
+ "word": "là ",
+ "start": 2.52,
+ "end": 2.58
+ },
+ {
+ "word": "c",
+ "start": 2.58,
+ "end": 2.74
+ },
+ {
+ "word": "'est",
+ "start": 2.74,
+ "end": 2.76
+ },
+ {
+ "word": "un",
+ "start": 2.76,
+ "end": 2.86
+ },
+ {
+ "word": "test,",
+ "start": 2.86,
+ "end": 3.2
+ },
+ {
+ "word": "on",
+ "start": 3.34,
+ "end": 3.34
+ },
+ {
+ "word": "veut",
+ "start": 3.34,
+ "end": 3.48
+ },
+ {
+ "word": "voir",
+ "start": 3.48,
+ "end": 3.86
+ },
+ {
+ "word": "si",
+ "start": 3.86,
+ "end": 4.14
+ },
+ {
+ "word": "ça",
+ "start": 4.14,
+ "end": 4.26
+ },
+ {
+ "word": "arrive",
+ "start": 4.26,
+ "end": 4.36
+ },
+ {
+ "word": "Ã ",
+ "start": 4.36,
+ "end": 4.5
+ },
+ {
+ "word": "capté",
+ "start": 4.5,
+ "end": 4.78
+ },
+ {
+ "word": "le",
+ "start": 4.78,
+ "end": 4.9
+ },
+ {
+ "word": "silence.",
+ "start": 4.9,
+ "end": 5.44
+ },
+ {
+ "word": "LÃ ",
+ "start": 9.24,
+ "end": 9.6
+ },
+ {
+ "word": "il",
+ "start": 9.6,
+ "end": 9.78
+ },
+ {
+ "word": "est",
+ "start": 9.78,
+ "end": 9.84
+ },
+ {
+ "word": "une",
+ "start": 9.84,
+ "end": 9.96
+ },
+ {
+ "word": "telle",
+ "start": 9.96,
+ "end": 10.12
+ },
+ {
+ "word": "seconde",
+ "start": 10.12,
+ "end": 10.38
+ },
+ {
+ "word": "de",
+ "start": 10.38,
+ "end": 10.48
+ },
+ {
+ "word": "silence",
+ "start": 10.48,
+ "end": 10.78
+ },
+ {
+ "word": "et",
+ "start": 10.78,
+ "end": 11.06
+ },
+ {
+ "word": "je",
+ "start": 11.06,
+ "end": 11.16
+ },
+ {
+ "word": "vous",
+ "start": 11.16,
+ "end": 11.32
+ },
+ {
+ "word": "parle.",
+ "start": 11.32,
+ "end": 11.68
+ },
+ {
+ "word": "Et",
+ "start": 13.28,
+ "end": 13.64
+ },
+ {
+ "word": "voilà ,",
+ "start": 13.64,
+ "end": 13.96
+ },
+ {
+ "word": "allez",
+ "start": 14.36,
+ "end": 14.62
+ },
+ {
+ "word": "on",
+ "start": 14.62,
+ "end": 14.78
+ },
+ {
+ "word": "va",
+ "start": 14.78,
+ "end": 14.88
+ },
+ {
+ "word": "tester",
+ "start": 14.88,
+ "end": 15.06
+ },
+ {
+ "word": "ça.",
+ "start": 15.06,
+ "end": 15.36
+ }
+]
\ No newline at end of file
diff --git a/audio_tests/00_00_30_english_3_speakers.transcript.json b/audio_tests/00_00_30_english_3_speakers.transcript.json
new file mode 100644
index 0000000..bb9d097
--- /dev/null
+++ b/audio_tests/00_00_30_english_3_speakers.transcript.json
@@ -0,0 +1,382 @@
+[
+ {
+ "word": "Transcription",
+ "start": 0.0,
+ "end": 0.6
+ },
+ {
+ "word": "technology",
+ "start": 0.6,
+ "end": 1.24
+ },
+ {
+ "word": "has",
+ "start": 1.24,
+ "end": 1.5
+ },
+ {
+ "word": "improved",
+ "start": 1.5,
+ "end": 1.96
+ },
+ {
+ "word": "so",
+ "start": 1.96,
+ "end": 2.32
+ },
+ {
+ "word": "much",
+ "start": 2.32,
+ "end": 2.68
+ },
+ {
+ "word": "in",
+ "start": 2.68,
+ "end": 2.94
+ },
+ {
+ "word": "the",
+ "start": 2.94,
+ "end": 3.02
+ },
+ {
+ "word": "past",
+ "start": 3.02,
+ "end": 3.24
+ },
+ {
+ "word": "few",
+ "start": 3.24,
+ "end": 3.5
+ },
+ {
+ "word": "years.",
+ "start": 3.5,
+ "end": 3.96
+ },
+ {
+ "word": "Have",
+ "start": 4.56,
+ "end": 4.74
+ },
+ {
+ "word": "you",
+ "start": 4.74,
+ "end": 4.9
+ },
+ {
+ "word": "noticed",
+ "start": 4.9,
+ "end": 5.26
+ },
+ {
+ "word": "how",
+ "start": 5.26,
+ "end": 5.52
+ },
+ {
+ "word": "accurate",
+ "start": 5.52,
+ "end": 6.08
+ },
+ {
+ "word": "real",
+ "start": 6.08,
+ "end": 6.42
+ },
+ {
+ "word": "-time",
+ "start": 6.42,
+ "end": 6.74
+ },
+ {
+ "word": "speech",
+ "start": 6.74,
+ "end": 7.24
+ },
+ {
+ "word": "to",
+ "start": 7.24,
+ "end": 7.46
+ },
+ {
+ "word": "text",
+ "start": 7.46,
+ "end": 7.78
+ },
+ {
+ "word": "is",
+ "start": 7.78,
+ "end": 8.0
+ },
+ {
+ "word": "now?",
+ "start": 8.0,
+ "end": 8.3
+ },
+ {
+ "word": "Absolutely.",
+ "start": 8.7,
+ "end": 9.16
+ },
+ {
+ "word": "I",
+ "start": 10.04,
+ "end": 10.38
+ },
+ {
+ "word": "use",
+ "start": 10.38,
+ "end": 10.56
+ },
+ {
+ "word": "it",
+ "start": 10.56,
+ "end": 10.76
+ },
+ {
+ "word": "all",
+ "start": 10.76,
+ "end": 10.9
+ },
+ {
+ "word": "the",
+ "start": 10.9,
+ "end": 11.04
+ },
+ {
+ "word": "time",
+ "start": 11.04,
+ "end": 11.32
+ },
+ {
+ "word": "for",
+ "start": 11.32,
+ "end": 11.54
+ },
+ {
+ "word": "taking",
+ "start": 11.54,
+ "end": 11.86
+ },
+ {
+ "word": "notes",
+ "start": 11.86,
+ "end": 12.16
+ },
+ {
+ "word": "during",
+ "start": 12.16,
+ "end": 12.54
+ },
+ {
+ "word": "meetings.",
+ "start": 12.54,
+ "end": 12.94
+ },
+ {
+ "word": "It's",
+ "start": 13.6,
+ "end": 13.8
+ },
+ {
+ "word": "amazing",
+ "start": 13.8,
+ "end": 14.1
+ },
+ {
+ "word": "how",
+ "start": 14.1,
+ "end": 14.48
+ },
+ {
+ "word": "it",
+ "start": 14.48,
+ "end": 14.62
+ },
+ {
+ "word": "can",
+ "start": 14.62,
+ "end": 14.74
+ },
+ {
+ "word": "recognise",
+ "start": 14.74,
+ "end": 15.24
+ },
+ {
+ "word": "different",
+ "start": 15.24,
+ "end": 15.68
+ },
+ {
+ "word": "speakers",
+ "start": 15.68,
+ "end": 16.16
+ },
+ {
+ "word": "and",
+ "start": 16.16,
+ "end": 16.8
+ },
+ {
+ "word": "even",
+ "start": 16.8,
+ "end": 17.1
+ },
+ {
+ "word": "add",
+ "start": 17.1,
+ "end": 17.44
+ },
+ {
+ "word": "punctuation.",
+ "start": 17.44,
+ "end": 18.36
+ },
+ {
+ "word": "Yeah,",
+ "start": 18.88,
+ "end": 19.16
+ },
+ {
+ "word": "but",
+ "start": 19.36,
+ "end": 19.52
+ },
+ {
+ "word": "sometimes",
+ "start": 19.52,
+ "end": 20.16
+ },
+ {
+ "word": "noise",
+ "start": 20.16,
+ "end": 20.54
+ },
+ {
+ "word": "can",
+ "start": 20.54,
+ "end": 20.8
+ },
+ {
+ "word": "still",
+ "start": 20.8,
+ "end": 21.1
+ },
+ {
+ "word": "cause",
+ "start": 21.1,
+ "end": 21.44
+ },
+ {
+ "word": "mistakes.",
+ "start": 21.44,
+ "end": 21.94
+ },
+ {
+ "word": "Does",
+ "start": 22.68,
+ "end": 22.9
+ },
+ {
+ "word": "this",
+ "start": 22.9,
+ "end": 23.12
+ },
+ {
+ "word": "system",
+ "start": 23.12,
+ "end": 23.46
+ },
+ {
+ "word": "handle",
+ "start": 23.46,
+ "end": 23.88
+ },
+ {
+ "word": "that",
+ "start": 23.88,
+ "end": 24.12
+ },
+ {
+ "word": "well?",
+ "start": 24.12,
+ "end": 24.42
+ },
+ {
+ "word": "It",
+ "start": 24.42,
+ "end": 25.32
+ },
+ {
+ "word": "does",
+ "start": 25.32,
+ "end": 25.48
+ },
+ {
+ "word": "a",
+ "start": 25.48,
+ "end": 25.62
+ },
+ {
+ "word": "pretty",
+ "start": 25.62,
+ "end": 25.88
+ },
+ {
+ "word": "good",
+ "start": 25.88,
+ "end": 26.08
+ },
+ {
+ "word": "job",
+ "start": 26.08,
+ "end": 26.32
+ },
+ {
+ "word": "filtering",
+ "start": 26.32,
+ "end": 26.8
+ },
+ {
+ "word": "noise,",
+ "start": 26.8,
+ "end": 27.18
+ },
+ {
+ "word": "especially",
+ "start": 27.36,
+ "end": 28.0
+ },
+ {
+ "word": "with",
+ "start": 28.0,
+ "end": 28.28
+ },
+ {
+ "word": "models",
+ "start": 28.28,
+ "end": 28.62
+ },
+ {
+ "word": "that",
+ "start": 28.62,
+ "end": 28.94
+ },
+ {
+ "word": "use",
+ "start": 28.94,
+ "end": 29.22
+ },
+ {
+ "word": "voice",
+ "start": 29.22,
+ "end": 29.54
+ },
+ {
+ "word": "active.",
+ "start": 29.54,
+ "end": 29.9
+ }
+]
\ No newline at end of file
diff --git a/audio_tests/generate_transcripts.py b/audio_tests/generate_transcripts.py
new file mode 100644
index 0000000..7eb180f
--- /dev/null
+++ b/audio_tests/generate_transcripts.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+"""Generate word-level timestamped transcripts using faster-whisper (offline).
+
+Produces one JSON file per audio with: [{word, start, end}, ...]
+"""
+
+import json
+import os
+from faster_whisper import WhisperModel
+
+AUDIO_DIR = os.path.dirname(os.path.abspath(__file__))
+
+FILES = [
+ ("00_00_07_english_1_speaker.wav", "en"),
+ ("00_00_16_french_1_speaker.wav", "fr"),
+ ("00_00_30_english_3_speakers.wav", "en"),
+]
+
+def main():
+ print("Loading faster-whisper model (base, cpu, float32)...")
+ model = WhisperModel("base", device="cpu", compute_type="float32")
+
+ for filename, lang in FILES:
+ audio_path = os.path.join(AUDIO_DIR, filename)
+ out_path = os.path.join(
+ AUDIO_DIR, filename.rsplit(".", 1)[0] + ".transcript.json"
+ )
+
+ print(f"\n{'='*60}")
+ print(f"Transcribing: {filename} (language={lang})")
+ print(f"{'='*60}")
+
+ segments, info = model.transcribe(
+ audio_path, word_timestamps=True, language=lang
+ )
+
+ words = []
+ for segment in segments:
+ if segment.words:
+ for w in segment.words:
+ words.append({
+ "word": w.word.strip(),
+ "start": round(w.start, 3),
+ "end": round(w.end, 3),
+ })
+ print(f" {w.start:6.2f} - {w.end:6.2f} {w.word.strip()}")
+
+ with open(out_path, "w", encoding="utf-8") as f:
+ json.dump(words, f, indent=2, ensure_ascii=False)
+
+ print(f"\n -> {len(words)} words written to {os.path.basename(out_path)}")
+
+ print("\nDone.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark_chart.png b/benchmark_chart.png
new file mode 100644
index 0000000..20123bd
Binary files /dev/null and b/benchmark_chart.png differ
diff --git a/benchmark_scatter.png b/benchmark_scatter.png
new file mode 100644
index 0000000..9f62bf3
Binary files /dev/null and b/benchmark_scatter.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 74ade12..9a79780 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
-version = "0.2.18"
+version = "0.2.19"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
@@ -42,6 +42,7 @@ dependencies = [
]
[project.optional-dependencies]
+test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
voxtral-hf = ["transformers>=5.2.0", "mistral-common[audio]"]
@@ -64,6 +65,7 @@ packages = [
"whisperlivekit.whisper.normalizers",
"whisperlivekit.web",
"whisperlivekit.local_agreement",
+ "whisperlivekit.voxtral_mlx",
"whisperlivekit.silero_vad_models"
]
diff --git a/run_benchmark.py b/run_benchmark.py
new file mode 100644
index 0000000..5a4e23b
--- /dev/null
+++ b/run_benchmark.py
@@ -0,0 +1,291 @@
+#!/usr/bin/env python3
+"""
+Comprehensive benchmark runner for WhisperLiveKit.
+
+Tests all available backend+policy combinations across multiple audio files,
+model sizes, and VAC on/off configurations. Outputs structured JSON that
+is consumed by the report generator.
+
+Usage:
+ python run_benchmark.py # full benchmark
+ python run_benchmark.py --quick # subset (tiny models, fewer combos)
+ python run_benchmark.py --json results.json # custom output path
+"""
+
+import argparse
+import asyncio
+import gc
+import json
+import logging
+import platform
+import subprocess
+import sys
+import time
+from dataclasses import asdict
+from pathlib import Path
+
+logging.basicConfig(level=logging.WARNING, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger("benchmark")
+logger.setLevel(logging.INFO)
+
+# Re-use harness functions
+sys.path.insert(0, str(Path(__file__).parent))
+from test_backend_offline import (
+ AUDIO_TESTS_DIR,
+ SAMPLE_RATE,
+ TestResult,
+ create_engine,
+ discover_audio_files,
+ download_sample_audio,
+ load_audio,
+ run_test,
+)
+
+CACHE_DIR = Path(__file__).parent / ".test_cache"
+
+
+def get_system_info() -> dict:
+ """Collect system metadata for the report."""
+ info = {
+ "platform": platform.platform(),
+ "machine": platform.machine(),
+ "processor": platform.processor(),
+ "python_version": platform.python_version(),
+ }
+
+ # macOS: get chip info
+ try:
+ chip = subprocess.check_output(
+ ["sysctl", "-n", "machdep.cpu.brand_string"], text=True
+ ).strip()
+ info["cpu"] = chip
+ except Exception:
+ info["cpu"] = platform.processor()
+
+ # RAM
+ try:
+ mem_bytes = int(
+ subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip()
+ )
+ info["ram_gb"] = round(mem_bytes / (1024**3))
+ except Exception:
+ info["ram_gb"] = None
+
+ # Backend versions
+ versions = {}
+ try:
+ import faster_whisper
+ versions["faster-whisper"] = faster_whisper.__version__
+ except ImportError:
+ pass
+ try:
+ import mlx_whisper # noqa: F401
+ versions["mlx-whisper"] = "installed"
+ except ImportError:
+ pass
+ try:
+ import mlx.core as mx
+ versions["mlx"] = mx.__version__
+ except ImportError:
+ pass
+ try:
+ import transformers
+ versions["transformers"] = transformers.__version__
+ except ImportError:
+ pass
+ try:
+ import torch
+ versions["torch"] = torch.__version__
+ except ImportError:
+ pass
+
+ info["backend_versions"] = versions
+ return info
+
+
+def detect_combos(quick: bool = False) -> list:
+ """Build list of (backend, policy, model_size) combos to test."""
+ combos = []
+
+ # Model sizes to test
+ model_sizes = ["tiny", "base", "small"] if not quick else ["tiny", "base"]
+
+ # faster-whisper
+ try:
+ import faster_whisper # noqa: F401
+ for model in model_sizes:
+ combos.append({"backend": "faster-whisper", "policy": "localagreement", "model": model})
+ combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "model": model})
+ except ImportError:
+ pass
+
+ # mlx-whisper
+ try:
+ import mlx_whisper # noqa: F401
+ for model in model_sizes:
+ combos.append({"backend": "mlx-whisper", "policy": "localagreement", "model": model})
+ combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "model": model})
+ except ImportError:
+ pass
+
+ # voxtral-mlx (single model, single policy)
+ try:
+ from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
+ combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "model": ""})
+ except ImportError:
+ pass
+
+ # voxtral HF (single model, single policy)
+ try:
+ from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
+ combos.append({"backend": "voxtral", "policy": "voxtral", "model": ""})
+ except ImportError:
+ pass
+
+ return combos
+
+
+def collect_audio_files() -> list:
+ """Collect all benchmark audio files."""
+ files = []
+
+ # audio_tests/ directory
+ if AUDIO_TESTS_DIR.is_dir():
+ files.extend(discover_audio_files(str(AUDIO_TESTS_DIR)))
+
+ # JFK sample
+ jfk = CACHE_DIR / "jfk.wav"
+ if not jfk.exists():
+ jfk = download_sample_audio()
+ if jfk.exists():
+ files.append(jfk)
+
+ return files
+
+
+async def run_single_combo(
+ combo: dict, audio_files: list, vac: bool, lan: str, max_duration: float,
+) -> list:
+ """Run one backend+policy+model combo across all audio files."""
+ backend = combo["backend"]
+ policy = combo["policy"]
+ model = combo["model"]
+
+ results = []
+ try:
+ engine = create_engine(
+ backend=backend,
+ model_size=model,
+ lan=lan,
+ vac=vac,
+ policy=policy,
+ )
+
+ # Quiet noisy loggers
+ for mod in (
+ "whisperlivekit.audio_processor",
+ "whisperlivekit.simul_whisper",
+ "whisperlivekit.tokens_alignment",
+ "whisperlivekit.simul_whisper.align_att_base",
+ "whisperlivekit.simul_whisper.simul_whisper",
+ ):
+ logging.getLogger(mod).setLevel(logging.WARNING)
+
+ for audio_path in audio_files:
+ duration = len(load_audio(str(audio_path))) / SAMPLE_RATE
+ if duration > max_duration:
+ logger.info(f" Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s)")
+ continue
+
+ file_lan = lan
+ if "french" in audio_path.name.lower() and lan == "en":
+ file_lan = "fr"
+
+ audio = load_audio(str(audio_path))
+ result = await run_test(
+ engine, audio, chunk_ms=100, realtime=False,
+ audio_file=audio_path.name, backend=backend,
+ policy=policy, lan=file_lan,
+ )
+ # Tag with extra metadata
+ result_dict = asdict(result)
+ result_dict["model_size"] = model
+ result_dict["vac"] = vac
+ results.append(result_dict)
+
+ except Exception as e:
+ logger.error(f" FAILED: {e}")
+ import traceback
+ traceback.print_exc()
+
+ return results
+
+
+async def run_full_benchmark(combos, audio_files, max_duration=60.0):
+ """Run all combos with VAC on and off."""
+ all_results = []
+ total = len(combos) * 2 # x2 for VAC on/off
+ idx = 0
+
+ for combo in combos:
+ for vac in [True, False]:
+ idx += 1
+ vac_str = "VAC=on" if vac else "VAC=off"
+ desc = f"{combo['backend']} / {combo['policy']}"
+ if combo["model"]:
+ desc += f" / {combo['model']}"
+ desc += f" / {vac_str}"
+
+ print(f"\n{'='*70}")
+ print(f"[{idx}/{total}] {desc}")
+ print(f"{'='*70}")
+
+ results = await run_single_combo(
+ combo, audio_files, vac=vac, lan="en", max_duration=max_duration,
+ )
+ all_results.extend(results)
+
+ # Free memory between combos
+ gc.collect()
+
+ return all_results
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Run comprehensive WhisperLiveKit benchmark")
+ parser.add_argument("--quick", action="store_true", help="Quick mode: fewer models and combos")
+ parser.add_argument("--json", default="benchmark_results.json", dest="json_output", help="Output JSON path")
+ parser.add_argument("--max-duration", type=float, default=60.0, help="Max audio duration in seconds")
+ args = parser.parse_args()
+
+ system_info = get_system_info()
+ combos = detect_combos(quick=args.quick)
+ audio_files = collect_audio_files()
+
+ print(f"System: {system_info.get('cpu', 'unknown')}, {system_info.get('ram_gb', '?')}GB RAM")
+ print(f"Backends: {list(system_info['backend_versions'].keys())}")
+ print(f"Combos to test: {len(combos)} x 2 (VAC on/off) = {len(combos)*2}")
+ print(f"Audio files: {[f.name for f in audio_files]}")
+ print()
+
+ t0 = time.time()
+ all_results = asyncio.run(
+ run_full_benchmark(combos, audio_files, max_duration=args.max_duration)
+ )
+ total_time = time.time() - t0
+
+ output = {
+ "system_info": system_info,
+ "benchmark_date": time.strftime("%Y-%m-%d %H:%M"),
+ "total_benchmark_time_s": round(total_time, 1),
+ "n_combos": len(combos) * 2,
+ "n_audio_files": len(audio_files),
+ "results": all_results,
+ }
+
+ Path(args.json_output).write_text(json.dumps(output, indent=2, ensure_ascii=False))
+ print(f"\nBenchmark complete in {total_time:.0f}s. Results: {args.json_output}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test_backend_offline.py b/test_backend_offline.py
new file mode 100644
index 0000000..486b715
--- /dev/null
+++ b/test_backend_offline.py
@@ -0,0 +1,783 @@
+#!/usr/bin/env python3
+"""
+Offline test harness and benchmark suite for WhisperLiveKit backends.
+
+Simulates a client-server session by feeding audio files as PCM bytes through
+the full AudioProcessor pipeline (the same path used by the WebSocket server),
+without needing a browser or microphone.
+
+Computes WER (Word Error Rate) and timestamp accuracy when ground truth
+transcript files (.transcript.json) are available alongside audio files.
+
+Usage:
+ # Test with a single audio file:
+ python test_backend_offline.py --backend faster-whisper --audio audio_tests/00_00_07_english_1_speaker.wav
+
+ # Test all files in audio_tests/:
+ python test_backend_offline.py --backend faster-whisper --no-realtime
+
+ # Override streaming policy:
+ python test_backend_offline.py --backend faster-whisper --policy simulstreaming --no-realtime
+
+ # Multi-backend benchmark (auto-detects all installed backends):
+ python test_backend_offline.py --benchmark --no-realtime
+
+ # Export results as JSON:
+ python test_backend_offline.py --benchmark --no-realtime --json results.json
+
+ # Insert silence for testing silence handling:
+ python test_backend_offline.py --backend faster-whisper --insert-silence 3.0 2.0
+"""
+
+import argparse
+import asyncio
+import json
+import logging
+import sys
+import time
+import urllib.request
+from pathlib import Path
+from dataclasses import dataclass, asdict, field
+from typing import List, Optional
+
+import numpy as np
+
+logging.basicConfig(
+ level=logging.WARNING,
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+)
+logger = logging.getLogger("test_offline")
+logger.setLevel(logging.INFO)
+
+SAMPLE_RATE = 16000
+JFK_WAV_URL = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
+CACHE_DIR = Path(__file__).parent / ".test_cache"
+AUDIO_TESTS_DIR = Path(__file__).parent / "audio_tests"
+AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
+
+
+@dataclass
+class WordTimestamp:
+ """Word with its start/end time."""
+ word: str
+ start: float
+ end: float
+
+
+@dataclass
+class TestResult:
+ """Structured result from a single test run."""
+ audio_file: str
+ audio_duration_s: float
+ backend: str
+ policy: str
+ language: str
+ chunk_ms: int
+ realtime_pacing: bool
+ # Timing
+ processing_time_s: float
+ rtf: float # real-time factor
+ # Transcription output
+ transcription: str
+ n_lines: int
+ n_responses: int
+ # WER metrics (None if no ground truth)
+ wer: Optional[float] = None
+ wer_details: Optional[dict] = None
+ # Timestamp accuracy (None if no ground truth)
+ timestamp_mae: Optional[float] = None
+ timestamp_max_delta: Optional[float] = None
+ timestamp_median_delta: Optional[float] = None
+ # Word-level timestamps
+ word_timestamps: List[WordTimestamp] = field(default_factory=list)
+ # Raw last response
+ last_response: Optional[dict] = None
+
+
+def download_sample_audio() -> Path:
+ """Download the jfk.wav sample if not cached."""
+ CACHE_DIR.mkdir(exist_ok=True)
+ path = CACHE_DIR / "jfk.wav"
+ if not path.exists():
+ logger.info(f"Downloading sample audio to {path} ...")
+ urllib.request.urlretrieve(JFK_WAV_URL, path)
+ logger.info("Done.")
+ return path
+
+
+def load_audio(path: str) -> np.ndarray:
+ """Load audio file as float32 mono 16kHz numpy array.
+
+ Supports WAV, FLAC (via soundfile) and MP3, OGG, M4A (via librosa).
+ """
+ ext = Path(path).suffix.lower()
+ if ext in (".mp3", ".ogg", ".m4a"):
+ import librosa
+ audio, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
+ return audio.astype(np.float32)
+
+ import soundfile as sf
+ audio, sr = sf.read(path, dtype="float32")
+ if audio.ndim > 1:
+ audio = audio.mean(axis=1)
+ if sr != SAMPLE_RATE:
+ import librosa
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
+ return audio
+
+
+def insert_silence(audio: np.ndarray, silence_sec: float, position_sec: float) -> np.ndarray:
+ """Insert silence into audio at a given position.
+
+ Args:
+ audio: Float32 mono audio array at SAMPLE_RATE.
+ silence_sec: Duration of silence to insert in seconds.
+ position_sec: Position in seconds where silence starts.
+ Returns:
+ New audio array with silence inserted.
+ """
+ pos_samples = int(position_sec * SAMPLE_RATE)
+ silence_samples = int(silence_sec * SAMPLE_RATE)
+ pos_samples = min(pos_samples, len(audio))
+ silence = np.zeros(silence_samples, dtype=np.float32)
+ return np.concatenate([audio[:pos_samples], silence, audio[pos_samples:]])
+
+
+def float32_to_s16le_bytes(audio: np.ndarray) -> bytes:
+ """Convert float32 audio to s16le PCM bytes (what the browser sends)."""
+ return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
+
+
+def create_engine(
+ backend: str, model_size: str, lan: str,
+ diarization: bool = False, vac: bool = True, policy: str = "",
+):
+ """Create a TranscriptionEngine with the given backend config."""
+ import gc
+ from whisperlivekit.core import TranscriptionEngine
+
+ # Reset singleton so we get a fresh instance
+ TranscriptionEngine._instance = None
+ TranscriptionEngine._initialized = False
+ gc.collect()
+
+ kwargs = dict(
+ backend=backend,
+ lan=lan,
+ pcm_input=True,
+ vac=vac,
+ transcription=True,
+ diarization=diarization,
+ )
+ if model_size:
+ kwargs["model_size"] = model_size
+ if policy:
+ kwargs["backend_policy"] = policy
+
+ return TranscriptionEngine(**kwargs)
+
+
+def _extract_text_from_response(response_dict: dict) -> str:
+ """Extract full transcription text from a FrontData dict."""
+ segments = response_dict.get("lines", [])
+ full_text = " ".join(
+ seg.get("text", "").strip()
+ for seg in segments
+ if seg.get("text", "").strip()
+ )
+ buf = response_dict.get("buffer_transcription", "").strip()
+ if buf:
+ full_text = f"{full_text} {buf}".strip() if full_text else buf
+ return full_text
+
+
+async def run_test(
+ engine, audio: np.ndarray, chunk_ms: int, realtime: bool,
+ audio_file: str = "", backend: str = "", policy: str = "", lan: str = "",
+) -> TestResult:
+ """
+ Simulate a client session through the full AudioProcessor pipeline.
+
+ 1. Create AudioProcessor (one per "client session")
+ 2. Start async pipeline (transcription_processor, results_formatter, etc.)
+ 3. Feed audio as PCM bytes in timed chunks
+ 4. Collect and display FrontData responses
+ 5. Signal EOF and cleanup
+ """
+ from whisperlivekit.audio_processor import AudioProcessor
+
+ chunk_samples = int(SAMPLE_RATE * chunk_ms / 1000)
+ total_samples = len(audio)
+ audio_duration = total_samples / SAMPLE_RATE
+
+ logger.info(
+ f"Audio: {audio_duration:.2f}s | "
+ f"Chunk: {chunk_ms}ms ({chunk_samples} samples) | "
+ f"Steps: {total_samples // chunk_samples + 1} | "
+ f"Realtime: {realtime}"
+ )
+
+ # --- Server side: create processor and start pipeline ---
+ processor = AudioProcessor(transcription_engine=engine)
+ results_generator = await processor.create_tasks()
+
+ # Collect results in background (like handle_websocket_results)
+ all_responses = []
+ response_count = 0
+ last_printed_text = ""
+
+ async def collect_results():
+ nonlocal response_count, last_printed_text
+ async for response in results_generator:
+ all_responses.append(response)
+ response_count += 1
+ d = response.to_dict()
+
+ # Only print when transcription text actually changes
+ current_text = _extract_text_from_response(d)
+ if current_text and current_text != last_printed_text:
+ buf = d.get("buffer_transcription", "").strip()
+ committed = current_text
+ if buf and committed.endswith(buf):
+ committed = committed[:-len(buf)].strip()
+
+ # Show committed text + buffer separately
+ display = committed
+ if buf:
+ display = f"{committed} \033[90m{buf}\033[0m" if committed else f"\033[90m{buf}\033[0m"
+ print(f" > {display}", flush=True)
+ last_printed_text = current_text
+
+ result_task = asyncio.create_task(collect_results())
+
+ # --- Client side: feed audio as PCM bytes ---
+ t_start = time.time()
+
+ for offset in range(0, total_samples, chunk_samples):
+ chunk = audio[offset : offset + chunk_samples]
+ pcm_bytes = float32_to_s16le_bytes(chunk)
+ await processor.process_audio(pcm_bytes)
+ if realtime:
+ await asyncio.sleep(chunk_ms / 1000)
+
+ feed_elapsed = time.time() - t_start
+
+ logger.info(f"Audio fed in {feed_elapsed:.2f}s. Signaling EOF...")
+
+ # Signal end of audio (like client disconnect / empty message)
+ await processor.process_audio(None)
+
+ # Wait for pipeline to drain completely
+ try:
+ await asyncio.wait_for(result_task, timeout=120.0)
+ except asyncio.TimeoutError:
+ logger.warning("Timed out waiting for results. Proceeding with cleanup.")
+ result_task.cancel()
+ try:
+ await result_task
+ except asyncio.CancelledError:
+ pass
+
+ # --- Capture word-level timestamps before cleanup ---
+ word_timestamps = []
+ try:
+ state = await processor.get_current_state()
+ for token in state.tokens:
+ if hasattr(token, 'start') and hasattr(token, 'text') and token.text:
+ word_timestamps.append(WordTimestamp(
+ word=token.text.strip(),
+ start=round(token.start, 3),
+ end=round(token.end, 3),
+ ))
+ except Exception as e:
+ logger.warning(f"Could not capture word timestamps: {e}")
+
+ # Cleanup
+ await processor.cleanup()
+
+ total_elapsed = time.time() - t_start
+
+ # --- Build result ---
+ transcription = ""
+ n_lines = 0
+ last_response_dict = None
+
+ if all_responses:
+ last = all_responses[-1].to_dict()
+ last_response_dict = last
+ n_lines = len(last.get("lines", []))
+ transcription = _extract_text_from_response(last)
+
+ # --- Compute WER and timestamp accuracy against ground truth ---
+ from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy
+
+ wer_val = None
+ wer_details = None
+ ts_mae = None
+ ts_max_delta = None
+ ts_median_delta = None
+
+ gt_path = Path(audio_file).with_suffix(".transcript.json")
+ if not gt_path.exists():
+ gt_path = AUDIO_TESTS_DIR / gt_path
+ gt = None
+ if gt_path.exists():
+ with open(gt_path) as f:
+ gt = json.load(f)
+
+ # WER
+ gt_text = " ".join(w["word"] for w in gt)
+ wer_result = compute_wer(gt_text, transcription)
+ wer_val = round(wer_result["wer"], 4)
+ wer_details = wer_result
+
+ # Timestamp accuracy
+ if word_timestamps:
+ pred_dicts = [{"word": wt.word, "start": wt.start, "end": wt.end} for wt in word_timestamps]
+ ts_result = compute_timestamp_accuracy(pred_dicts, gt)
+ ts_mae = ts_result["mae_start"]
+ ts_max_delta = ts_result["max_delta_start"]
+ ts_median_delta = ts_result["median_delta_start"]
+
+ result = TestResult(
+ audio_file=audio_file,
+ audio_duration_s=round(audio_duration, 2),
+ backend=backend,
+ policy=policy,
+ language=lan,
+ chunk_ms=chunk_ms,
+ realtime_pacing=realtime,
+ processing_time_s=round(total_elapsed, 2),
+ rtf=round(total_elapsed / audio_duration, 2),
+ transcription=transcription,
+ n_lines=n_lines,
+ n_responses=response_count,
+ wer=wer_val,
+ wer_details=wer_details,
+ timestamp_mae=round(ts_mae, 3) if ts_mae is not None else None,
+ timestamp_max_delta=round(ts_max_delta, 3) if ts_max_delta is not None else None,
+ timestamp_median_delta=round(ts_median_delta, 3) if ts_median_delta is not None else None,
+ word_timestamps=word_timestamps,
+ last_response=last_response_dict,
+ )
+
+ # --- Print summary ---
+ print(f"\n{'=' * 60}")
+ print(f"RESULT: {audio_file}")
+ print(f"{'=' * 60}")
+ print(f"Transcription: {transcription}")
+ print(f"Lines: {n_lines} | Responses: {response_count}")
+ print(f"Audio: {audio_duration:.2f}s | Time: {total_elapsed:.2f}s | RTF: {result.rtf:.2f}x")
+
+ if wer_val is not None:
+ print(f"WER: {wer_val:.2%} (S={wer_details['substitutions']} I={wer_details['insertions']} D={wer_details['deletions']})")
+
+ # Print word timestamps if available
+ if word_timestamps:
+ print(f"\nWord timestamps ({len(word_timestamps)} words):")
+ for wt in word_timestamps:
+ print(f" [{wt.start:6.2f} - {wt.end:6.2f}] {wt.word}")
+
+ # Detailed comparison with ground truth
+ if gt:
+ print(f"\n vs Ground truth ({len(gt)} words):")
+ max_words = max(len(word_timestamps), len(gt))
+ for i in range(max_words):
+ pred = word_timestamps[i] if i < len(word_timestamps) else None
+ ref = gt[i] if i < len(gt) else None
+ p_str = f"[{pred.start:5.2f}-{pred.end:5.2f}] {pred.word:<15}" if pred else " " * 30
+ r_str = f"[{ref['start']:5.2f}-{ref['end']:5.2f}] {ref['word']:<15}" if ref else ""
+ delta = ""
+ if pred and ref:
+ d = pred.start - ref['start']
+ delta = f" Δstart={d:+.2f}"
+ print(f" {p_str} | {r_str}{delta}")
+
+ if ts_mae is not None:
+ print(f"\n Timestamp stats: MAE={ts_mae:.3f}s max|Δ|={ts_max_delta:.3f}s median|Δ|={ts_median_delta:.3f}s")
+
+ print(f"{'=' * 60}")
+
+ return result
+
+
+def discover_audio_files(directory: str) -> List[Path]:
+ """Find all supported audio files in directory."""
+ d = Path(directory)
+ files = sorted(
+ p for p in d.iterdir()
+ if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
+ )
+ return files
+
+
+async def run_all_tests(
+ engine, audio_files: List[Path], chunk_ms: int, realtime: bool,
+ backend: str, policy: str, lan: str, max_duration: float = 60.0,
+ silence_insertions: Optional[List[List[float]]] = None,
+) -> List[TestResult]:
+ """Run tests on multiple audio files sequentially."""
+ results = []
+ for audio_path in audio_files:
+ # Detect language from filename if "french" in name
+ file_lan = lan
+ if "french" in audio_path.name.lower() and lan == "en":
+ file_lan = "fr"
+ logger.info(f"Auto-detected language 'fr' from filename")
+
+ audio = load_audio(str(audio_path))
+
+ # Insert silence segments (applied in reverse position order to keep offsets valid)
+ if silence_insertions:
+ for secs, at_sec in sorted(silence_insertions, key=lambda x: x[1], reverse=True):
+ logger.info(f"Inserting {secs:.1f}s silence at {at_sec:.1f}s")
+ audio = insert_silence(audio, secs, at_sec)
+
+ duration = len(audio) / SAMPLE_RATE
+
+ if duration > max_duration:
+ logger.info(f"Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s max)")
+ continue
+
+ print(f"\n{'#' * 60}")
+ print(f"# Testing: {audio_path.name} ({duration:.1f}s)")
+ print(f"{'#' * 60}")
+
+ result = await run_test(
+ engine, audio, chunk_ms, realtime,
+ audio_file=audio_path.name, backend=backend, policy=policy, lan=file_lan,
+ )
+ results.append(result)
+
+ return results
+
+
+def print_benchmark_summary(results: List[TestResult]):
+ """Print a tabular summary of all test results."""
+ print(f"\n{'=' * 110}")
+ print("BENCHMARK SUMMARY")
+ print(f"{'=' * 110}")
+ print(
+ f"{'File':<40} {'Duration':>8} {'Time':>8} {'RTF':>6} "
+ f"{'WER':>7} {'MAE(s)':>7} {'Lines':>5}"
+ )
+ print(f"{'-' * 110}")
+ for r in results:
+ wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
+ mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
+ print(
+ f"{r.audio_file:<40} {r.audio_duration_s:>7.1f}s {r.processing_time_s:>7.1f}s "
+ f"{r.rtf:>5.2f}x {wer_str:>7} {mae_str:>7} {r.n_lines:>5}"
+ )
+ print(f"{'-' * 110}")
+ total_audio = sum(r.audio_duration_s for r in results)
+ total_time = sum(r.processing_time_s for r in results)
+ avg_rtf = total_time / total_audio if total_audio > 0 else 0
+ wer_vals = [r.wer for r in results if r.wer is not None]
+ avg_wer_str = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
+ mae_vals = [r.timestamp_mae for r in results if r.timestamp_mae is not None]
+ avg_mae_str = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
+ print(
+ f"{'TOTAL/AVG':<40} {total_audio:>7.1f}s {total_time:>7.1f}s "
+ f"{avg_rtf:>5.2f}x {avg_wer_str:>7} {avg_mae_str:>7}"
+ )
+ print(f"{'=' * 110}")
+
+ # Print transcription excerpts
+ print(f"\nTRANSCRIPTIONS:")
+ print(f"{'-' * 110}")
+ for r in results:
+ excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription
+ print(f" {r.audio_file}:")
+ print(f" {excerpt}")
+ print(f"{'=' * 110}")
+
+
+def detect_available_backends() -> List[dict]:
+ """Probe which backends can be imported and return (backend, policy) combos.
+
+ Returns list of dicts with keys: backend, policy, description.
+ """
+ combos = []
+
+ # faster-whisper
+ try:
+ import faster_whisper # noqa: F401
+ combos.append({"backend": "faster-whisper", "policy": "localagreement", "description": "faster-whisper + LocalAgreement"})
+ combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "description": "faster-whisper + SimulStreaming"})
+ except ImportError:
+ pass
+
+ # mlx-whisper (macOS only)
+ try:
+ import mlx_whisper # noqa: F401
+ combos.append({"backend": "mlx-whisper", "policy": "localagreement", "description": "mlx-whisper + LocalAgreement"})
+ combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "description": "mlx-whisper + SimulStreaming"})
+ except ImportError:
+ pass
+
+ # openai-whisper
+ try:
+ import whisper # noqa: F401
+ combos.append({"backend": "whisper", "policy": "localagreement", "description": "openai-whisper + LocalAgreement"})
+ combos.append({"backend": "whisper", "policy": "simulstreaming", "description": "openai-whisper + SimulStreaming"})
+ except ImportError:
+ pass
+
+ # voxtral-mlx
+ try:
+ from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
+ combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "description": "voxtral-mlx (MLX)"})
+ except ImportError:
+ pass
+
+ # voxtral (HuggingFace)
+ try:
+ from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
+ combos.append({"backend": "voxtral", "policy": "voxtral", "description": "voxtral (HuggingFace)"})
+ except ImportError:
+ pass
+
+ return combos
+
+
+def print_cross_backend_comparison(all_results: List[TestResult]):
+ """Print a comparison table across backends and policies."""
+ print(f"\n{'=' * 110}")
+ print("CROSS-BACKEND BENCHMARK COMPARISON")
+ print(f"{'=' * 110}")
+ print(
+ f"{'Backend':<18} {'Policy':<16} {'File':<30} "
+ f"{'WER':>7} {'RTF':>6} {'MAE(s)':>7} {'MaxΔ(s)':>8}"
+ )
+ print(f"{'-' * 110}")
+
+ for r in all_results:
+ wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
+ rtf_str = f"{r.rtf:.2f}x"
+ mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
+ max_str = f"{r.timestamp_max_delta:.3f}" if r.timestamp_max_delta is not None else " -"
+ # Truncate filename for readability
+ fname = r.audio_file[:28] + ".." if len(r.audio_file) > 30 else r.audio_file
+ print(
+ f"{r.backend:<18} {r.policy:<16} {fname:<30} "
+ f"{wer_str:>7} {rtf_str:>6} {mae_str:>7} {max_str:>8}"
+ )
+
+ print(f"{'-' * 110}")
+
+ # Per-backend averages
+ from collections import defaultdict
+ by_combo = defaultdict(list)
+ for r in all_results:
+ by_combo[(r.backend, r.policy)].append(r)
+
+ print(f"\n{'Backend':<18} {'Policy':<16} {'Avg WER':>8} {'Avg RTF':>8} {'Avg MAE':>8} {'Files':>6}")
+ print(f"{'-' * 80}")
+ for (backend, policy), group in sorted(by_combo.items()):
+ wer_vals = [r.wer for r in group if r.wer is not None]
+ rtf_vals = [r.rtf for r in group]
+ mae_vals = [r.timestamp_mae for r in group if r.timestamp_mae is not None]
+ avg_wer = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
+ avg_rtf = f"{sum(rtf_vals)/len(rtf_vals):.2f}x"
+ avg_mae = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
+ print(
+ f"{backend:<18} {policy:<16} {avg_wer:>8} {avg_rtf:>8} {avg_mae:>8} {len(group):>6}"
+ )
+ print(f"{'=' * 110}")
+
+
+def _quiet_loggers(verbose: bool):
+ """Set internal module log levels to reduce noise."""
+ if verbose:
+ logging.getLogger().setLevel(logging.DEBUG)
+ else:
+ for mod in (
+ "whisperlivekit.audio_processor", "whisperlivekit.simul_whisper",
+ "whisperlivekit.tokens_alignment", "whisperlivekit.simul_whisper.align_att_base",
+ "whisperlivekit.simul_whisper.simul_whisper",
+ ):
+ logging.getLogger(mod).setLevel(logging.WARNING)
+
+
+async def run_benchmark(
+ audio_files: List[Path], chunk_ms: int, realtime: bool,
+ model_size: str, lan: str, max_duration: float, vac: bool,
+ verbose: bool,
+) -> List[TestResult]:
+ """Run benchmark across all available backend+policy combinations."""
+ combos = detect_available_backends()
+ if not combos:
+ logger.error("No backends available. Install at least one ASR backend.")
+ return []
+
+ logger.info(f"Detected {len(combos)} backend+policy combinations:")
+ for c in combos:
+ logger.info(f" - {c['description']}")
+
+ all_results = []
+ for i, combo in enumerate(combos, 1):
+ backend = combo["backend"]
+ policy = combo["policy"]
+ desc = combo["description"]
+
+ print(f"\n{'*' * 70}")
+ print(f"* BENCHMARK {i}/{len(combos)}: {desc}")
+ print(f"{'*' * 70}")
+
+ try:
+ engine = create_engine(
+ backend, model_size, lan, vac=vac, policy=policy,
+ )
+ _quiet_loggers(verbose)
+
+ results = await run_all_tests(
+ engine, audio_files, chunk_ms, realtime,
+ backend=backend, policy=policy, lan=lan,
+ max_duration=max_duration,
+ )
+ all_results.extend(results)
+ except Exception as e:
+ logger.error(f"Failed to run {desc}: {e}")
+ import traceback
+ traceback.print_exc()
+
+ return all_results
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Offline backend test harness (AudioProcessor-level)"
+ )
+ parser.add_argument(
+ "--backend", default="faster-whisper",
+ help="Backend: voxtral, voxtral-mlx, auto, faster-whisper, mlx-whisper, whisper.",
+ )
+ parser.add_argument(
+ "--policy", default="",
+ help="Override backend policy: localagreement, simulstreaming, voxtral.",
+ )
+ parser.add_argument(
+ "--audio", default=None,
+ help="Path to a single audio file (WAV, MP3, FLAC, etc.).",
+ )
+ parser.add_argument(
+ "--audio-dir", default=None,
+ help="Directory of audio files to test. Defaults to audio_tests/ if neither --audio nor --audio-dir given.",
+ )
+ parser.add_argument(
+ "--chunk-ms", type=int, default=100,
+ help="Chunk size in milliseconds (simulates real-time interval).",
+ )
+ parser.add_argument(
+ "--model", default="", dest="model_size",
+ help="Model size or HF repo ID.",
+ )
+ parser.add_argument("--lan", default="en", help="Language code.")
+ parser.add_argument(
+ "--no-realtime", action="store_true",
+ help="Skip real-time pacing between chunks (faster but less realistic).",
+ )
+ parser.add_argument(
+ "--no-vac", action="store_true",
+ help="Disable Voice Activity Classification (send all audio without silence filtering).",
+ )
+ parser.add_argument(
+ "--diarization", action="store_true",
+ help="Enable speaker diarization.",
+ )
+ parser.add_argument(
+ "--benchmark", action="store_true",
+ help="Run benchmark across all detected backend+policy combinations.",
+ )
+ parser.add_argument(
+ "--json", default=None, dest="json_output",
+ help="Write structured JSON results to this file.",
+ )
+ parser.add_argument(
+ "--max-duration", type=float, default=60.0,
+ help="Skip audio files longer than this many seconds (default: 60).",
+ )
+ parser.add_argument(
+ "--insert-silence", nargs=2, type=float, metavar=("SECS", "AT_SEC"),
+ action="append", default=[],
+ help="Insert SECS of silence at AT_SEC position. Can be repeated. "
+ "E.g.: --insert-silence 3.0 2.0 --insert-silence 5.0 7.0",
+ )
+ parser.add_argument(
+ "-v", "--verbose", action="store_true",
+ help="Show debug-level logs from all components.",
+ )
+ args = parser.parse_args()
+
+ realtime = not args.no_realtime
+ vac = not args.no_vac
+
+ # Resolve audio file(s)
+ if args.audio:
+ audio_files = [Path(args.audio)]
+ elif args.audio_dir:
+ audio_files = discover_audio_files(args.audio_dir)
+ elif AUDIO_TESTS_DIR.is_dir():
+ audio_files = discover_audio_files(str(AUDIO_TESTS_DIR))
+ else:
+ # Fall back to jfk.wav download
+ audio_files = [download_sample_audio()]
+
+ if not audio_files:
+ logger.error("No audio files found.")
+ sys.exit(1)
+
+ logger.info(f"Audio files: {[f.name for f in audio_files]}")
+
+ if args.benchmark:
+ # --- Multi-backend benchmark mode ---
+ all_results = asyncio.run(
+ run_benchmark(
+ audio_files, args.chunk_ms, realtime,
+ args.model_size, args.lan, args.max_duration, vac,
+ args.verbose,
+ )
+ )
+ if all_results:
+ print_cross_backend_comparison(all_results)
+ results = all_results
+ else:
+ # --- Single-backend mode ---
+ policy = args.policy
+ logger.info(f"Creating {args.backend} engine...")
+ engine = create_engine(
+ args.backend, args.model_size, args.lan,
+ diarization=args.diarization, vac=vac, policy=policy,
+ )
+ logger.info("Engine ready.")
+
+ _quiet_loggers(args.verbose)
+
+ results = asyncio.run(
+ run_all_tests(
+ engine, audio_files, args.chunk_ms, realtime,
+ args.backend, policy, args.lan,
+ max_duration=args.max_duration,
+ silence_insertions=args.insert_silence or None,
+ )
+ )
+
+ if len(results) > 1:
+ print_benchmark_summary(results)
+
+ # JSON output
+ if args.json_output and results:
+ json_results = []
+ for r in results:
+ d = asdict(r)
+ d.pop("last_response", None) # too verbose for summary
+ json_results.append(d)
+ Path(args.json_output).write_text(
+ json.dumps(json_results, indent=2, ensure_ascii=False)
+ )
+ logger.info(f"Results written to {args.json_output}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..1a26f33
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,58 @@
+"""Shared pytest fixtures for WhisperLiveKit tests."""
+
+import json
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+from whisperlivekit.timed_objects import ASRToken, Silence, Transcript
+
+
+AUDIO_TESTS_DIR = Path(__file__).parent.parent / "audio_tests"
+
+
+@pytest.fixture
+def sample_tokens():
+ """A short sequence of ASRToken objects."""
+ return [
+ ASRToken(start=0.0, end=0.5, text="Hello"),
+ ASRToken(start=0.5, end=1.0, text=" world"),
+ ASRToken(start=1.0, end=1.5, text=" test."),
+ ]
+
+
+@pytest.fixture
+def sample_silence():
+ """A completed silence event."""
+ s = Silence(start=1.5, end=3.0, is_starting=False, has_ended=True)
+ s.compute_duration()
+ return s
+
+
+@pytest.fixture
+def mock_args():
+ """Minimal args namespace for AudioProcessor tests."""
+ return SimpleNamespace(
+ diarization=False,
+ transcription=True,
+ target_language="",
+ vac=False,
+ vac_chunk_size=0.04,
+ min_chunk_size=0.1,
+ pcm_input=True,
+ punctuation_split=False,
+ backend="faster-whisper",
+ backend_policy="localagreement",
+ vad=True,
+ )
+
+
+@pytest.fixture
+def ground_truth_en():
+ """Ground truth transcript for the 7s English audio (if available)."""
+ path = AUDIO_TESTS_DIR / "00_00_07_english_1_speaker.transcript.json"
+ if path.exists():
+ with open(path) as f:
+ return json.load(f)
+ return None
diff --git a/tests/test_audio_processor.py b/tests/test_audio_processor.py
new file mode 100644
index 0000000..9286108
--- /dev/null
+++ b/tests/test_audio_processor.py
@@ -0,0 +1,209 @@
+"""Tests for AudioProcessor pipeline with mocked ASR backends.
+
+These tests verify the async audio processing pipeline works correctly
+without requiring any real ASR models to be loaded.
+"""
+
+import asyncio
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+from whisperlivekit.timed_objects import ASRToken, Transcript
+
+
+# ---------------------------------------------------------------------------
+# Mock ASR components
+# ---------------------------------------------------------------------------
+
+class MockASR:
+ """Mock ASR model holder."""
+ sep = " "
+ SAMPLING_RATE = 16000
+
+ def __init__(self):
+ self.transcribe_kargs = {}
+ self.original_language = "en"
+ self.backend_choice = "mock"
+
+ def transcribe(self, audio):
+ return None
+
+
+class MockOnlineProcessor:
+ """Mock online processor that returns canned tokens."""
+ SAMPLING_RATE = 16000
+
+ def __init__(self, asr=None):
+ self.asr = asr or MockASR()
+ self.audio_buffer = np.array([], dtype=np.float32)
+ self.end = 0.0
+ self._call_count = 0
+ self._finished = False
+
+ def insert_audio_chunk(self, audio, audio_stream_end_time):
+ self.audio_buffer = np.append(self.audio_buffer, audio)
+ self.end = audio_stream_end_time
+
+ def process_iter(self, is_last=False):
+ self._call_count += 1
+ # Emit a token on every call when we have audio
+ if len(self.audio_buffer) > 0:
+ t = self._call_count * 0.5
+ return [ASRToken(start=t, end=t + 0.5, text=f"word{self._call_count}")], self.end
+ return [], self.end
+
+ def get_buffer(self):
+ return Transcript(start=None, end=None, text="")
+
+ def start_silence(self):
+ return [], self.end
+
+ def end_silence(self, silence_duration, offset):
+ pass
+
+ def new_speaker(self, change_speaker):
+ pass
+
+ def finish(self):
+ self._finished = True
+ return [], self.end
+
+ def warmup(self, audio, init_prompt=""):
+ pass
+
+
+def _make_pcm_bytes(duration_s=0.1, sample_rate=16000):
+ """Generate silent PCM s16le bytes."""
+ n_samples = int(duration_s * sample_rate)
+ audio = np.zeros(n_samples, dtype=np.float32)
+ return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def mock_engine():
+ """Create a mock TranscriptionEngine-like object."""
+ engine = SimpleNamespace(
+ asr=MockASR(),
+ diarization_model=None,
+ translation_model=None,
+ args=SimpleNamespace(
+ diarization=False,
+ transcription=True,
+ target_language="",
+ vac=False,
+ vac_chunk_size=0.04,
+ min_chunk_size=0.1,
+ pcm_input=True,
+ punctuation_split=False,
+ backend="mock",
+ backend_policy="localagreement",
+ vad=True,
+ model_size="base",
+ lan="en",
+ ),
+ )
+ return engine
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestPCMConversion:
+ """Test PCM byte conversion without needing the full pipeline."""
+
+ def test_s16le_roundtrip(self):
+ """Convert float32 → s16le → float32 and verify approximate roundtrip."""
+ original = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype=np.float32)
+ s16 = (original * 32768).clip(-32768, 32767).astype(np.int16)
+ pcm_bytes = s16.tobytes()
+ # Direct numpy conversion (same logic as AudioProcessor.convert_pcm_to_float)
+ recovered = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
+
+ np.testing.assert_allclose(recovered, original, atol=1 / 32768)
+
+
+@pytest.mark.asyncio
+class TestPipelineBasics:
+ async def test_feed_audio_and_get_responses(self, mock_engine):
+ """Feed audio through the pipeline and verify we get responses."""
+ from whisperlivekit.audio_processor import AudioProcessor
+
+ with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
+ processor = AudioProcessor(transcription_engine=mock_engine)
+ results_gen = await processor.create_tasks()
+
+ responses = []
+
+ async def collect():
+ async for resp in results_gen:
+ responses.append(resp)
+
+ task = asyncio.create_task(collect())
+
+ # Feed 2 seconds of audio in 100ms chunks
+ for _ in range(20):
+ await processor.process_audio(_make_pcm_bytes(0.1))
+
+ # Signal EOF
+ await processor.process_audio(None)
+
+ await asyncio.wait_for(task, timeout=10.0)
+ await processor.cleanup()
+
+ # We should have gotten at least one response
+ assert len(responses) > 0
+
+ async def test_eof_terminates_pipeline(self, mock_engine):
+ """Sending None (EOF) should cleanly terminate the pipeline."""
+ from whisperlivekit.audio_processor import AudioProcessor
+
+ with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
+ processor = AudioProcessor(transcription_engine=mock_engine)
+ results_gen = await processor.create_tasks()
+
+ responses = []
+
+ async def collect():
+ async for resp in results_gen:
+ responses.append(resp)
+
+ task = asyncio.create_task(collect())
+
+ # Send a small amount of audio then EOF
+ await processor.process_audio(_make_pcm_bytes(0.5))
+ await processor.process_audio(None)
+
+ await asyncio.wait_for(task, timeout=10.0)
+ await processor.cleanup()
+
+ # Pipeline should have terminated without error
+ assert task.done()
+
+ async def test_empty_audio_no_crash(self, mock_engine):
+ """Sending EOF immediately (no audio) should not crash."""
+ from whisperlivekit.audio_processor import AudioProcessor
+
+ with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
+ processor = AudioProcessor(transcription_engine=mock_engine)
+ results_gen = await processor.create_tasks()
+
+ responses = []
+
+ async def collect():
+ async for resp in results_gen:
+ responses.append(resp)
+
+ task = asyncio.create_task(collect())
+ await processor.process_audio(None)
+
+ await asyncio.wait_for(task, timeout=10.0)
+ await processor.cleanup()
+ assert task.done()
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..23f4c56
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,99 @@
+"""Tests for WhisperLiveKitConfig."""
+
+import logging
+from types import SimpleNamespace
+
+import pytest
+
+from whisperlivekit.config import WhisperLiveKitConfig
+
+
+class TestDefaults:
+ def test_default_backend(self):
+ c = WhisperLiveKitConfig()
+ assert c.backend == "auto"
+
+ def test_default_policy(self):
+ c = WhisperLiveKitConfig()
+ assert c.backend_policy == "simulstreaming"
+
+ def test_default_language(self):
+ c = WhisperLiveKitConfig()
+ assert c.lan == "auto"
+
+ def test_default_vac(self):
+ c = WhisperLiveKitConfig()
+ assert c.vac is True
+
+ def test_default_model_size(self):
+ c = WhisperLiveKitConfig()
+ assert c.model_size == "base"
+
+ def test_default_transcription(self):
+ c = WhisperLiveKitConfig()
+ assert c.transcription is True
+ assert c.diarization is False
+
+
+class TestPostInit:
+ def test_en_model_forces_english(self):
+ c = WhisperLiveKitConfig(model_size="tiny.en")
+ assert c.lan == "en"
+
+ def test_en_suffix_with_auto_language(self):
+ c = WhisperLiveKitConfig(model_size="base.en", lan="auto")
+ assert c.lan == "en"
+
+ def test_non_en_model_keeps_language(self):
+ c = WhisperLiveKitConfig(model_size="base", lan="fr")
+ assert c.lan == "fr"
+
+ def test_policy_alias_1(self):
+ c = WhisperLiveKitConfig(backend_policy="1")
+ assert c.backend_policy == "simulstreaming"
+
+ def test_policy_alias_2(self):
+ c = WhisperLiveKitConfig(backend_policy="2")
+ assert c.backend_policy == "localagreement"
+
+ def test_policy_no_alias(self):
+ c = WhisperLiveKitConfig(backend_policy="localagreement")
+ assert c.backend_policy == "localagreement"
+
+
+class TestFromNamespace:
+ def test_known_keys(self):
+ ns = SimpleNamespace(backend="faster-whisper", lan="en", model_size="large-v3")
+ c = WhisperLiveKitConfig.from_namespace(ns)
+ assert c.backend == "faster-whisper"
+ assert c.lan == "en"
+ assert c.model_size == "large-v3"
+
+ def test_ignores_unknown_keys(self):
+ ns = SimpleNamespace(backend="auto", unknown_key="value", another="x")
+ c = WhisperLiveKitConfig.from_namespace(ns)
+ assert c.backend == "auto"
+ assert not hasattr(c, "unknown_key")
+
+ def test_preserves_defaults_for_missing(self):
+ ns = SimpleNamespace(backend="voxtral-mlx")
+ c = WhisperLiveKitConfig.from_namespace(ns)
+ assert c.lan == "auto"
+ assert c.vac is True
+
+
+class TestFromKwargs:
+ def test_known_keys(self):
+ c = WhisperLiveKitConfig.from_kwargs(backend="mlx-whisper", lan="fr")
+ assert c.backend == "mlx-whisper"
+ assert c.lan == "fr"
+
+ def test_warns_on_unknown_keys(self, caplog):
+ with caplog.at_level(logging.WARNING, logger="whisperlivekit.config"):
+ c = WhisperLiveKitConfig.from_kwargs(backend="auto", bogus="value")
+ assert c.backend == "auto"
+ assert "bogus" in caplog.text
+
+ def test_post_init_runs(self):
+ c = WhisperLiveKitConfig.from_kwargs(model_size="small.en")
+ assert c.lan == "en"
diff --git a/tests/test_hypothesis_buffer.py b/tests/test_hypothesis_buffer.py
new file mode 100644
index 0000000..732090a
--- /dev/null
+++ b/tests/test_hypothesis_buffer.py
@@ -0,0 +1,172 @@
+"""Tests for HypothesisBuffer — the core of LocalAgreement policy."""
+
+import pytest
+
+from whisperlivekit.timed_objects import ASRToken
+from whisperlivekit.local_agreement.online_asr import HypothesisBuffer
+
+
+def make_tokens(words, start=0.0, step=0.5):
+ """Helper: create ASRToken list from word strings."""
+ tokens = []
+ t = start
+ for w in words:
+ tokens.append(ASRToken(start=t, end=t + step, text=w, probability=0.9))
+ t += step
+ return tokens
+
+
+class TestInsert:
+ def test_basic_insert(self):
+ buf = HypothesisBuffer()
+ tokens = make_tokens(["hello", "world"])
+ buf.insert(tokens, offset=0.0)
+ assert len(buf.new) == 2
+ assert buf.new[0].text == "hello"
+
+ def test_insert_with_offset(self):
+ buf = HypothesisBuffer()
+ tokens = make_tokens(["hello"], start=0.0)
+ buf.insert(tokens, offset=5.0)
+ assert buf.new[0].start == pytest.approx(5.0)
+
+ def test_insert_filters_old_tokens(self):
+ buf = HypothesisBuffer()
+ buf.last_committed_time = 10.0
+ tokens = make_tokens(["old", "new"], start=5.0, step=3.0)
+ buf.insert(tokens, offset=0.0)
+ # "old" at 5.0 is before last_committed_time - 0.1 = 9.9 → filtered
+ # "new" at 8.0 is also before 9.9 → filtered
+ assert len(buf.new) == 0
+
+ def test_insert_deduplicates_committed(self):
+ buf = HypothesisBuffer()
+ # Commit "hello"
+ tokens1 = make_tokens(["hello", "world"])
+ buf.insert(tokens1, offset=0.0)
+ buf.flush() # commits "hello" (buffer was empty, so nothing matches)
+ # Actually with empty buffer, flush won't commit anything
+ # Let's do it properly: two rounds
+ buf2 = HypothesisBuffer()
+ first = make_tokens(["hello", "world"])
+ buf2.insert(first, offset=0.0)
+ buf2.flush() # buffer was empty → no commits, buffer = ["hello", "world"]
+
+ second = make_tokens(["hello", "world", "test"])
+ buf2.insert(second, offset=0.0)
+ committed = buf2.flush()
+ # LCP of ["hello", "world"] and ["hello", "world", "test"] = ["hello", "world"]
+ assert len(committed) == 2
+ assert committed[0].text == "hello"
+ assert committed[1].text == "world"
+
+
+class TestFlush:
+ def test_flush_empty(self):
+ buf = HypothesisBuffer()
+ committed = buf.flush()
+ assert committed == []
+
+ def test_flush_lcp_matching(self):
+ buf = HypothesisBuffer()
+ # Round 1: establish buffer
+ buf.insert(make_tokens(["hello", "world"]), offset=0.0)
+ buf.flush() # buffer = ["hello", "world"], committed = []
+
+ # Round 2: same prefix, new suffix
+ buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
+ committed = buf.flush()
+ assert [t.text for t in committed] == ["hello", "world"]
+
+ def test_flush_no_match(self):
+ buf = HypothesisBuffer()
+ # Round 1
+ buf.insert(make_tokens(["hello", "world"]), offset=0.0)
+ buf.flush()
+
+ # Round 2: completely different
+ buf.insert(make_tokens(["foo", "bar"]), offset=0.0)
+ committed = buf.flush()
+ assert committed == []
+
+ def test_flush_partial_match(self):
+ buf = HypothesisBuffer()
+ buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
+ buf.flush()
+
+ buf.insert(make_tokens(["hello", "earth", "again"]), offset=0.0)
+ committed = buf.flush()
+ assert len(committed) == 1
+ assert committed[0].text == "hello"
+
+ def test_flush_updates_last_committed(self):
+ buf = HypothesisBuffer()
+ buf.insert(make_tokens(["hello", "world"]), offset=0.0)
+ buf.flush()
+
+ buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
+ buf.flush()
+ assert buf.last_committed_word == "world"
+ assert buf.last_committed_time > 0
+
+ def test_flush_with_confidence_validation(self):
+ buf = HypothesisBuffer(confidence_validation=True)
+ high_conf = [
+ ASRToken(start=0.0, end=0.5, text="sure", probability=0.99),
+ ASRToken(start=0.5, end=1.0, text="maybe", probability=0.5),
+ ]
+ buf.insert(high_conf, offset=0.0)
+ committed = buf.flush()
+ # "sure" has p>0.95 → committed immediately
+ assert len(committed) == 1
+ assert committed[0].text == "sure"
+
+
+class TestPopCommitted:
+ def test_pop_removes_old(self):
+ buf = HypothesisBuffer()
+ buf.committed_in_buffer = make_tokens(["a", "b", "c"], start=0.0, step=1.0)
+ # "a": end=1.0, "b": end=2.0, "c": end=3.0
+ # pop_committed removes tokens with end <= time
+ buf.pop_committed(2.0)
+ # "a" (end=1.0) and "b" (end=2.0) removed, "c" (end=3.0) remains
+ assert len(buf.committed_in_buffer) == 1
+ assert buf.committed_in_buffer[0].text == "c"
+
+ def test_pop_nothing(self):
+ buf = HypothesisBuffer()
+ buf.committed_in_buffer = make_tokens(["a", "b"], start=5.0)
+ buf.pop_committed(0.0)
+ assert len(buf.committed_in_buffer) == 2
+
+ def test_pop_all(self):
+ buf = HypothesisBuffer()
+ buf.committed_in_buffer = make_tokens(["a", "b"], start=0.0, step=0.5)
+ buf.pop_committed(100.0)
+ assert len(buf.committed_in_buffer) == 0
+
+
+class TestStreamingSimulation:
+ """Multi-round insert/flush simulating real streaming behavior."""
+
+ def test_three_rounds(self):
+ buf = HypothesisBuffer()
+ all_committed = []
+
+ # Round 1: "this is"
+ buf.insert(make_tokens(["this", "is"]), offset=0.0)
+ all_committed.extend(buf.flush())
+
+ # Round 2: "this is a test"
+ buf.insert(make_tokens(["this", "is", "a", "test"]), offset=0.0)
+ all_committed.extend(buf.flush())
+
+ # Round 3: "this is a test today"
+ buf.insert(make_tokens(["this", "is", "a", "test", "today"]), offset=0.0)
+ all_committed.extend(buf.flush())
+
+ words = [t.text for t in all_committed]
+ assert "this" in words
+ assert "is" in words
+ assert "a" in words
+ assert "test" in words
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000..4412b32
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,183 @@
+"""Tests for whisperlivekit.metrics — WER, timestamp accuracy, normalization."""
+
+import pytest
+
+from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy, normalize_text
+
+
+class TestNormalizeText:
+ def test_lowercase(self):
+ assert normalize_text("Hello World") == "hello world"
+
+ def test_strip_punctuation(self):
+ assert normalize_text("Hello, world!") == "hello world"
+
+ def test_collapse_whitespace(self):
+ assert normalize_text(" hello world ") == "hello world"
+
+ def test_keep_hyphens(self):
+ assert normalize_text("real-time") == "real-time"
+
+ def test_keep_apostrophes(self):
+ assert normalize_text("don't") == "don't"
+
+ def test_unicode_normalized(self):
+ # e + combining accent should be same as precomposed
+ assert normalize_text("caf\u0065\u0301") == normalize_text("caf\u00e9")
+
+ def test_empty(self):
+ assert normalize_text("") == ""
+
+ def test_only_punctuation(self):
+ assert normalize_text("...!?") == ""
+
+
+class TestComputeWER:
+ def test_perfect_match(self):
+ result = compute_wer("hello world", "hello world")
+ assert result["wer"] == 0.0
+ assert result["substitutions"] == 0
+ assert result["insertions"] == 0
+ assert result["deletions"] == 0
+
+ def test_case_insensitive(self):
+ result = compute_wer("Hello World", "hello world")
+ assert result["wer"] == 0.0
+
+ def test_punctuation_ignored(self):
+ result = compute_wer("Hello, world!", "hello world")
+ assert result["wer"] == 0.0
+
+ def test_one_substitution(self):
+ result = compute_wer("hello world", "hello earth")
+ assert result["wer"] == pytest.approx(0.5)
+ assert result["substitutions"] == 1
+
+ def test_one_insertion(self):
+ result = compute_wer("hello world", "hello big world")
+ assert result["wer"] == pytest.approx(0.5)
+ assert result["insertions"] == 1
+
+ def test_one_deletion(self):
+ result = compute_wer("hello big world", "hello world")
+ assert result["wer"] == pytest.approx(1 / 3)
+ assert result["deletions"] == 1
+
+ def test_completely_different(self):
+ result = compute_wer("the cat sat", "a dog ran")
+ assert result["wer"] == pytest.approx(1.0)
+
+ def test_empty_reference(self):
+ result = compute_wer("", "hello")
+ assert result["wer"] == 1.0 # 1 insertion / 0 ref → treated as float(m)
+ assert result["ref_words"] == 0
+
+ def test_empty_hypothesis(self):
+ result = compute_wer("hello world", "")
+ assert result["wer"] == pytest.approx(1.0)
+ assert result["deletions"] == 2
+
+ def test_both_empty(self):
+ result = compute_wer("", "")
+ assert result["wer"] == 0.0
+
+ def test_ref_and_hyp_word_counts(self):
+ result = compute_wer("one two three", "one two three four")
+ assert result["ref_words"] == 3
+ assert result["hyp_words"] == 4
+
+
+class TestComputeTimestampAccuracy:
+ def test_perfect_match(self):
+ words = [
+ {"word": "hello", "start": 0.0, "end": 0.5},
+ {"word": "world", "start": 0.5, "end": 1.0},
+ ]
+ result = compute_timestamp_accuracy(words, words)
+ assert result["mae_start"] == 0.0
+ assert result["max_delta_start"] == 0.0
+ assert result["n_matched"] == 2
+
+ def test_constant_offset(self):
+ ref = [
+ {"word": "hello", "start": 0.0, "end": 0.5},
+ {"word": "world", "start": 0.5, "end": 1.0},
+ ]
+ pred = [
+ {"word": "hello", "start": 0.1, "end": 0.6},
+ {"word": "world", "start": 0.6, "end": 1.1},
+ ]
+ result = compute_timestamp_accuracy(pred, ref)
+ assert result["mae_start"] == pytest.approx(0.1)
+ assert result["max_delta_start"] == pytest.approx(0.1)
+ assert result["n_matched"] == 2
+
+ def test_mismatched_word_counts(self):
+ ref = [
+ {"word": "hello", "start": 0.0, "end": 0.5},
+ {"word": "beautiful", "start": 0.5, "end": 1.0},
+ {"word": "world", "start": 1.0, "end": 1.5},
+ ]
+ pred = [
+ {"word": "hello", "start": 0.0, "end": 0.5},
+ {"word": "world", "start": 1.1, "end": 1.6},
+ ]
+ result = compute_timestamp_accuracy(pred, ref)
+ assert result["n_matched"] == 2
+ assert result["n_ref"] == 3
+ assert result["n_pred"] == 2
+
+ def test_empty_predicted(self):
+ ref = [{"word": "hello", "start": 0.0, "end": 0.5}]
+ result = compute_timestamp_accuracy([], ref)
+ assert result["mae_start"] is None
+ assert result["n_matched"] == 0
+
+ def test_empty_reference(self):
+ pred = [{"word": "hello", "start": 0.0, "end": 0.5}]
+ result = compute_timestamp_accuracy(pred, [])
+ assert result["mae_start"] is None
+ assert result["n_matched"] == 0
+
+ def test_case_insensitive_matching(self):
+ ref = [{"word": "Hello", "start": 0.0, "end": 0.5}]
+ pred = [{"word": "hello", "start": 0.1, "end": 0.6}]
+ result = compute_timestamp_accuracy(pred, ref)
+ assert result["n_matched"] == 1
+ assert result["mae_start"] == pytest.approx(0.1)
+
+ def test_median_even_count(self):
+ """Median with even number of matched words should average the two middle values."""
+ ref = [
+ {"word": "a", "start": 0.0, "end": 0.2},
+ {"word": "b", "start": 0.5, "end": 0.7},
+ {"word": "c", "start": 1.0, "end": 1.2},
+ {"word": "d", "start": 1.5, "end": 1.7},
+ ]
+ pred = [
+ {"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
+ {"word": "b", "start": 0.7, "end": 0.9}, # delta 0.2
+ {"word": "c", "start": 1.3, "end": 1.5}, # delta 0.3
+ {"word": "d", "start": 1.9, "end": 2.1}, # delta 0.4
+ ]
+ result = compute_timestamp_accuracy(pred, ref)
+ assert result["n_matched"] == 4
+ # sorted abs deltas: [0.1, 0.2, 0.3, 0.4] -> median = (0.2 + 0.3) / 2 = 0.25
+ assert result["median_delta_start"] == pytest.approx(0.25)
+
+ def test_median_odd_count(self):
+ """Median with odd number of matched words takes the middle value."""
+ ref = [
+ {"word": "a", "start": 0.0, "end": 0.2},
+ {"word": "b", "start": 0.5, "end": 0.7},
+ {"word": "c", "start": 1.0, "end": 1.2},
+ ]
+ pred = [
+ {"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
+ {"word": "b", "start": 0.8, "end": 1.0}, # delta 0.3
+ {"word": "c", "start": 1.2, "end": 1.4}, # delta 0.2
+ ]
+ result = compute_timestamp_accuracy(pred, ref)
+ assert result["n_matched"] == 3
+ # sorted abs deltas: [0.1, 0.2, 0.3] -> median = 0.2
+ assert result["median_delta_start"] == pytest.approx(0.2)
diff --git a/tests/test_silence_handling.py b/tests/test_silence_handling.py
new file mode 100644
index 0000000..08028be
--- /dev/null
+++ b/tests/test_silence_handling.py
@@ -0,0 +1,99 @@
+"""Tests for silence handling — state machine and double-counting regression."""
+
+import pytest
+
+from whisperlivekit.timed_objects import Silence
+
+
+class TestSilenceStateMachine:
+ """Test Silence object state transitions."""
+
+ def test_initial_state(self):
+ s = Silence(start=1.0, is_starting=True)
+ assert s.is_starting is True
+ assert s.has_ended is False
+ assert s.duration is None
+ assert s.end is None
+
+ def test_end_silence(self):
+ s = Silence(start=1.0, is_starting=True)
+ s.end = 3.0
+ s.is_starting = False
+ s.has_ended = True
+ s.compute_duration()
+ assert s.duration == pytest.approx(2.0)
+
+ def test_very_short_silence(self):
+ s = Silence(start=1.0, end=1.01, is_starting=False, has_ended=True)
+ s.compute_duration()
+ assert s.duration == pytest.approx(0.01)
+
+ def test_zero_duration_silence(self):
+ s = Silence(start=5.0, end=5.0)
+ s.compute_duration()
+ assert s.duration == pytest.approx(0.0)
+
+
+class TestSilenceDoubleCounting:
+ """Regression tests for the silence double-counting bug.
+
+ The bug: _begin_silence and _end_silence both pushed self.current_silence
+ to the queue. Since they were the same Python object, _end_silence's mutation
+ affected the already-queued start event. The consumer processed both as
+ ended silences, doubling the duration.
+
+ Fix: _begin_silence now pushes a separate Silence object for the start event.
+ """
+
+ def test_start_and_end_are_separate_objects(self):
+ """Simulate the fix: start event and end event must be different objects."""
+ # Simulate _begin_silence: creates start event as separate object
+ current_silence = Silence(start=1.0, is_starting=True)
+ start_event = Silence(start=1.0, is_starting=True) # separate copy
+
+ # Simulate _end_silence: mutates current_silence
+ current_silence.end = 3.0
+ current_silence.is_starting = False
+ current_silence.has_ended = True
+ current_silence.compute_duration()
+
+ # start_event should NOT be affected by mutations to current_silence
+ assert start_event.is_starting is True
+ assert start_event.has_ended is False
+ assert start_event.end is None
+
+ # current_silence (end event) has the final state
+ assert current_silence.has_ended is True
+ assert current_silence.duration == pytest.approx(2.0)
+
+ def test_single_object_would_cause_double_counting(self):
+ """Demonstrate the bug: if same object is used for both events."""
+ shared = Silence(start=1.0, is_starting=True)
+ queue = [shared] # start event queued
+
+ # Mutate (simulates _end_silence)
+ shared.end = 3.0
+ shared.is_starting = False
+ shared.has_ended = True
+ shared.compute_duration()
+ queue.append(shared) # end event queued
+
+ # Both queue items point to the SAME mutated object
+ assert queue[0] is queue[1] # same reference
+ assert queue[0].has_ended is True # start event also shows ended!
+
+ # This would cause double-counting: both items have has_ended=True
+ # and duration=2.0, so the consumer adds 2.0 twice = 4.0
+
+
+class TestConsecutiveSilences:
+ def test_multiple_silences(self):
+ """Multiple silence periods should have independent durations."""
+ s1 = Silence(start=1.0, end=2.0)
+ s1.compute_duration()
+ s2 = Silence(start=5.0, end=8.0)
+ s2.compute_duration()
+ assert s1.duration == pytest.approx(1.0)
+ assert s2.duration == pytest.approx(3.0)
+ # Total silence should be sum, not accumulated on single object
+ assert s1.duration + s2.duration == pytest.approx(4.0)
diff --git a/tests/test_timed_objects.py b/tests/test_timed_objects.py
new file mode 100644
index 0000000..559a1c3
--- /dev/null
+++ b/tests/test_timed_objects.py
@@ -0,0 +1,185 @@
+"""Tests for whisperlivekit.timed_objects data classes."""
+
+import pytest
+
+from whisperlivekit.timed_objects import (
+ ASRToken,
+ FrontData,
+ Segment,
+ Silence,
+ TimedText,
+ Transcript,
+ format_time,
+)
+
+
+class TestFormatTime:
+ def test_zero(self):
+ assert format_time(0) == "0:00:00"
+
+ def test_one_minute(self):
+ assert format_time(60) == "0:01:00"
+
+ def test_one_hour(self):
+ assert format_time(3600) == "1:00:00"
+
+ def test_fractional_truncated(self):
+ assert format_time(61.9) == "0:01:01"
+
+
+class TestASRToken:
+ def test_with_offset(self):
+ t = ASRToken(start=1.0, end=2.0, text="hello")
+ shifted = t.with_offset(0.5)
+ assert shifted.start == pytest.approx(1.5)
+ assert shifted.end == pytest.approx(2.5)
+ assert shifted.text == "hello"
+
+ def test_with_offset_preserves_fields(self):
+ t = ASRToken(start=0.0, end=1.0, text="hi", speaker=2, probability=0.95)
+ shifted = t.with_offset(1.0)
+ assert shifted.speaker == 2
+ assert shifted.probability == 0.95
+
+ def test_is_silence_false(self):
+ t = ASRToken(start=0.0, end=1.0, text="hello")
+ assert t.is_silence() is False
+
+ def test_bool_truthy(self):
+ t = ASRToken(start=0.0, end=1.0, text="hello")
+ assert bool(t) is True
+
+ def test_bool_falsy(self):
+ t = ASRToken(start=0.0, end=1.0, text="")
+ assert bool(t) is False
+
+
+class TestTimedText:
+ def test_has_punctuation_period(self):
+ t = TimedText(text="hello.")
+ assert t.has_punctuation() is True
+
+ def test_has_punctuation_exclamation(self):
+ t = TimedText(text="wow!")
+ assert t.has_punctuation() is True
+
+ def test_has_punctuation_question(self):
+ t = TimedText(text="really?")
+ assert t.has_punctuation() is True
+
+ def test_has_punctuation_cjk(self):
+ t = TimedText(text="hello。")
+ assert t.has_punctuation() is True
+
+ def test_no_punctuation(self):
+ t = TimedText(text="hello world")
+ assert t.has_punctuation() is False
+
+ def test_duration(self):
+ t = TimedText(start=1.0, end=3.5)
+ assert t.duration() == pytest.approx(2.5)
+
+ def test_contains_timespan(self):
+ outer = TimedText(start=0.0, end=5.0)
+ inner = TimedText(start=1.0, end=3.0)
+ assert outer.contains_timespan(inner) is True
+ assert inner.contains_timespan(outer) is False
+
+
+class TestSilence:
+ def test_compute_duration(self):
+ s = Silence(start=1.0, end=3.5)
+ d = s.compute_duration()
+ assert d == pytest.approx(2.5)
+ assert s.duration == pytest.approx(2.5)
+
+ def test_compute_duration_none_start(self):
+ s = Silence(start=None, end=3.5)
+ d = s.compute_duration()
+ assert d is None
+
+ def test_compute_duration_none_end(self):
+ s = Silence(start=1.0, end=None)
+ d = s.compute_duration()
+ assert d is None
+
+ def test_is_silence_true(self):
+ s = Silence()
+ assert s.is_silence() is True
+
+
+class TestTranscript:
+ def test_from_tokens(self, sample_tokens):
+ t = Transcript.from_tokens(sample_tokens, sep="")
+ assert t.text == "Hello world test."
+ assert t.start == pytest.approx(0.0)
+ assert t.end == pytest.approx(1.5)
+
+ def test_from_tokens_with_sep(self, sample_tokens):
+ t = Transcript.from_tokens(sample_tokens, sep="|")
+ assert t.text == "Hello| world| test."
+
+ def test_from_empty_tokens(self):
+ t = Transcript.from_tokens([])
+ assert t.text == ""
+ assert t.start is None
+ assert t.end is None
+
+ def test_from_tokens_with_offset(self, sample_tokens):
+ t = Transcript.from_tokens(sample_tokens, offset=10.0)
+ assert t.start == pytest.approx(10.0)
+ assert t.end == pytest.approx(11.5)
+
+
+class TestSegment:
+ def test_from_tokens(self, sample_tokens):
+ seg = Segment.from_tokens(sample_tokens)
+ assert seg is not None
+ assert seg.text == "Hello world test."
+ assert seg.start == pytest.approx(0.0)
+ assert seg.end == pytest.approx(1.5)
+ assert seg.speaker == -1
+
+ def test_from_silence_tokens(self):
+ silences = [
+ Silence(start=1.0, end=2.0),
+ Silence(start=2.0, end=3.0),
+ ]
+ seg = Segment.from_tokens(silences, is_silence=True)
+ assert seg is not None
+ assert seg.speaker == -2
+ assert seg.is_silence() is True
+ assert seg.text is None
+
+ def test_from_empty_tokens(self):
+ seg = Segment.from_tokens([])
+ assert seg is None
+
+ def test_to_dict(self, sample_tokens):
+ seg = Segment.from_tokens(sample_tokens)
+ d = seg.to_dict()
+ assert "text" in d
+ assert "speaker" in d
+ assert "start" in d
+ assert "end" in d
+
+
+class TestFrontData:
+ def test_to_dict_empty(self):
+ fd = FrontData()
+ d = fd.to_dict()
+ assert d["lines"] == []
+ assert d["buffer_transcription"] == ""
+ assert "error" not in d
+
+ def test_to_dict_with_error(self):
+ fd = FrontData(error="something broke")
+ d = fd.to_dict()
+ assert d["error"] == "something broke"
+
+ def test_to_dict_with_lines(self, sample_tokens):
+ seg = Segment.from_tokens(sample_tokens)
+ fd = FrontData(lines=[seg])
+ d = fd.to_dict()
+ assert len(d["lines"]) == 1
+ assert d["lines"][0]["text"] == "Hello world test."
diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index 37c6a44..3d43c03 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -9,6 +9,7 @@ import numpy as np
from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory,
online_translation_factory)
+from whisperlivekit.metrics_collector import SessionMetrics
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
@@ -118,6 +119,7 @@ class AudioProcessor:
self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
+ self.metrics: SessionMetrics = SessionMetrics()
self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None
@@ -139,25 +141,43 @@ class AudioProcessor:
if self.translation_queue:
await self.translation_queue.put(self.current_silence)
- async def _begin_silence(self) -> None:
+ async def _begin_silence(self, at_sample: Optional[int] = None) -> None:
if self.current_silence:
return
- now = time() - self.beg_loop
+ # Use audio stream time (sample-precise) for accurate silence duration
+ if at_sample is not None:
+ audio_t = at_sample / self.sample_rate
+ else:
+ audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
self.current_silence = Silence(
- is_starting=True, start=now
+ is_starting=True, start=audio_t
)
- await self._push_silence_event()
+ # Push a separate start-only event so _end_silence won't mutate it
+ start_event = Silence(is_starting=True, start=audio_t)
+ if self.transcription_queue:
+ await self.transcription_queue.put(start_event)
+ if self.args.diarization and self.diarization_queue:
+ await self.diarization_queue.put(start_event)
+ if self.translation_queue:
+ await self.translation_queue.put(start_event)
- async def _end_silence(self) -> None:
+ async def _end_silence(self, at_sample: Optional[int] = None) -> None:
if not self.current_silence:
return
- now = time() - self.beg_loop
- self.current_silence.end = now
- self.current_silence.is_starting=False
- self.current_silence.has_ended=True
+ if at_sample is not None:
+ audio_t = at_sample / self.sample_rate
+ else:
+ audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
+ self.current_silence.end = audio_t
+ self.current_silence.is_starting = False
+ self.current_silence.has_ended = True
self.current_silence.compute_duration()
+ self.metrics.n_silence_events += 1
+ if self.current_silence.duration is not None:
+ self.metrics.total_silence_duration_s += self.current_silence.duration
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
self.state.new_tokens.append(self.current_silence)
+ # Push the completed silence as the end event (separate from the start event)
await self._push_silence_event()
self.current_silence = None
@@ -253,6 +273,34 @@ class AudioProcessor:
if self.translation:
await self.translation_queue.put(SENTINEL)
+ async def _finish_transcription(self) -> None:
+ """Call finish() on the online processor to flush remaining tokens."""
+ if not self.transcription:
+ return
+ try:
+ if hasattr(self.transcription, 'finish'):
+ final_tokens, end_time = await asyncio.to_thread(self.transcription.finish)
+ else:
+ # SimulStreamingOnlineProcessor uses start_silence() → process_iter(is_last=True)
+ final_tokens, end_time = await asyncio.to_thread(self.transcription.start_silence)
+
+ final_tokens = final_tokens or []
+ if final_tokens:
+ logger.info(f"Finish flushed {len(final_tokens)} tokens")
+ _buffer_transcript = self.transcription.get_buffer()
+ async with self.lock:
+ self.state.tokens.extend(final_tokens)
+ self.state.buffer_transcription = _buffer_transcript
+ self.state.end_buffer = max(self.state.end_buffer, end_time)
+ self.state.new_tokens.extend(final_tokens)
+ self.state.new_tokens_buffer = _buffer_transcript
+ if self.translation_queue:
+ for token in final_tokens:
+ await self.translation_queue.put(token)
+ except Exception as e:
+ logger.warning(f"Error finishing transcription: {e}")
+ logger.debug(f"Traceback: {traceback.format_exc()}")
+
async def transcription_processor(self) -> None:
"""Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0
@@ -263,6 +311,7 @@ class AudioProcessor:
item = await get_all_from_queue(self.transcription_queue)
if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.")
+ await self._finish_transcription()
break
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
@@ -297,8 +346,13 @@ class AudioProcessor:
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
+ _t0 = time()
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
+ _dur = time() - _t0
+ self.metrics.transcription_durations.append(_dur)
+ self.metrics.n_transcription_calls += 1
new_tokens = new_tokens or []
+ self.metrics.n_tokens_produced += len(new_tokens)
_buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text
@@ -433,6 +487,7 @@ class AudioProcessor:
should_push = (response != self.last_response_content)
if should_push:
+ self.metrics.n_responses_sent += 1
yield response
self.last_response_content = response
@@ -535,6 +590,10 @@ class AudioProcessor:
logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.diarization:
self.diarization.close()
+
+ # Finalize session metrics
+ self.metrics.total_audio_duration_s = self.total_pcm_samples / self.sample_rate
+ self.metrics.log_summary()
logger.info("AudioProcessor cleanup complete.")
def _processing_tasks_done(self) -> bool:
@@ -553,6 +612,7 @@ class AudioProcessor:
if not self.beg_loop:
self.beg_loop = time()
+ self.metrics.session_start = self.beg_loop
self.current_silence = Silence(start=0.0, is_starting=True)
self.tokens_alignment.beg_loop = self.beg_loop
@@ -560,6 +620,10 @@ class AudioProcessor:
logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True
+ # Flush any remaining PCM data before signaling end-of-stream
+ if self.is_pcm_input and self.pcm_buffer:
+ await self._flush_remaining_pcm()
+
if self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
@@ -572,6 +636,8 @@ class AudioProcessor:
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
return
+ self.metrics.n_chunks_received += 1
+
if self.is_pcm_input:
self.pcm_buffer.extend(message)
await self.handle_pcm_data()
@@ -588,6 +654,11 @@ class AudioProcessor:
logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self) -> None:
+ # Without VAC, there's no speech detector to end the initial silence.
+ # Clear it on the first audio chunk so audio actually gets enqueued.
+ if not self.args.vac and self.current_silence:
+ await self._end_silence()
+
# Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec:
return
@@ -616,7 +687,7 @@ class AudioProcessor:
if res is not None:
if "start" in res and self.current_silence:
- await self._end_silence()
+ await self._end_silence(at_sample=res.get("start"))
if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence(
@@ -624,7 +695,7 @@ class AudioProcessor:
)
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
await self._enqueue_active_audio(pre_silence_chunk)
- await self._begin_silence()
+ await self._begin_silence(at_sample=res.get("end"))
if not self.current_silence:
await self._enqueue_active_audio(pcm_array)
@@ -633,3 +704,21 @@ class AudioProcessor:
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
+
+ async def _flush_remaining_pcm(self) -> None:
+ """Flush whatever PCM data remains in the buffer, regardless of size threshold."""
+ if not self.pcm_buffer:
+ return
+ aligned_size = (len(self.pcm_buffer) // self.bytes_per_sample) * self.bytes_per_sample
+ if aligned_size == 0:
+ return
+ pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_size])
+ self.pcm_buffer = self.pcm_buffer[aligned_size:]
+
+ # End any active silence so the audio gets enqueued
+ if self.current_silence:
+ await self._end_silence(at_sample=self.total_pcm_samples)
+
+ await self._enqueue_active_audio(pcm_array)
+ self.total_pcm_samples += len(pcm_array)
+ logger.info(f"Flushed remaining PCM buffer: {len(pcm_array)} samples ({len(pcm_array)/self.sample_rate:.2f}s)")
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
index 7cf1041..c306f52 100644
--- a/whisperlivekit/core.py
+++ b/whisperlivekit/core.py
@@ -92,7 +92,12 @@ class TranscriptionEngine:
}
if config.transcription:
- if config.backend == "voxtral":
+ if config.backend == "voxtral-mlx":
+ from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
+ self.tokenizer = None
+ self.asr = VoxtralMLXASR(**transcription_common_params)
+ logger.info("Using Voxtral MLX native backend")
+ elif config.backend == "voxtral":
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingASR
self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
@@ -169,6 +174,9 @@ class TranscriptionEngine:
def online_factory(args, asr):
+ if getattr(args, 'backend', None) == "voxtral-mlx":
+ from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
+ return VoxtralMLXOnlineProcessor(asr)
if getattr(args, 'backend', None) == "voxtral":
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
return VoxtralHFStreamingOnlineProcessor(asr)
diff --git a/whisperlivekit/metrics.py b/whisperlivekit/metrics.py
new file mode 100644
index 0000000..8bbd9af
--- /dev/null
+++ b/whisperlivekit/metrics.py
@@ -0,0 +1,156 @@
+"""Lightweight ASR evaluation metrics — no external dependencies.
+
+Provides WER (Word Error Rate) computation via word-level Levenshtein distance,
+text normalization, and word-level timestamp accuracy metrics with greedy alignment.
+"""
+
+import re
+import unicodedata
+from typing import Dict, List, Optional
+
+
+def normalize_text(text: str) -> str:
+ """Normalize text for WER comparison: lowercase, strip punctuation, collapse whitespace."""
+ text = text.lower()
+ # Normalize unicode (e.g., accented chars to composed form)
+ text = unicodedata.normalize("NFC", text)
+ # Remove punctuation (keep letters, numbers, spaces, hyphens within words)
+ text = re.sub(r"[^\w\s\-']", " ", text)
+ # Collapse whitespace
+ text = re.sub(r"\s+", " ", text).strip()
+ return text
+
+
+def compute_wer(reference: str, hypothesis: str) -> Dict:
+ """Compute Word Error Rate using word-level Levenshtein edit distance.
+
+ Args:
+ reference: Ground truth transcription.
+ hypothesis: Predicted transcription.
+
+ Returns:
+ Dict with keys: wer, substitutions, insertions, deletions, ref_words, hyp_words.
+ WER can exceed 1.0 if there are more errors than reference words.
+ """
+ ref_words = normalize_text(reference).split()
+ hyp_words = normalize_text(hypothesis).split()
+
+ n = len(ref_words)
+ m = len(hyp_words)
+
+ if n == 0:
+ return {
+ "wer": 0.0 if m == 0 else float(m),
+ "substitutions": 0,
+ "insertions": m,
+ "deletions": 0,
+ "ref_words": 0,
+ "hyp_words": m,
+ }
+
+ # DP table: dp[i][j] = (edit_distance, substitutions, insertions, deletions)
+ dp = [[(0, 0, 0, 0) for _ in range(m + 1)] for _ in range(n + 1)]
+
+ for i in range(1, n + 1):
+ dp[i][0] = (i, 0, 0, i)
+ for j in range(1, m + 1):
+ dp[0][j] = (j, 0, j, 0)
+
+ for i in range(1, n + 1):
+ for j in range(1, m + 1):
+ if ref_words[i - 1] == hyp_words[j - 1]:
+ dp[i][j] = dp[i - 1][j - 1]
+ else:
+ sub = dp[i - 1][j - 1]
+ ins = dp[i][j - 1]
+ dele = dp[i - 1][j]
+
+ sub_cost = (sub[0] + 1, sub[1] + 1, sub[2], sub[3])
+ ins_cost = (ins[0] + 1, ins[1], ins[2] + 1, ins[3])
+ del_cost = (dele[0] + 1, dele[1], dele[2], dele[3] + 1)
+
+ dp[i][j] = min(sub_cost, del_cost, ins_cost, key=lambda x: x[0])
+
+ dist, subs, ins, dels = dp[n][m]
+ return {
+ "wer": dist / n,
+ "substitutions": subs,
+ "insertions": ins,
+ "deletions": dels,
+ "ref_words": n,
+ "hyp_words": m,
+ }
+
+
+def compute_timestamp_accuracy(
+ predicted: List[Dict],
+ reference: List[Dict],
+) -> Dict:
+ """Compute timestamp accuracy by aligning predicted words to reference words.
+
+ Uses greedy left-to-right alignment on normalized text. For each matched pair,
+ computes the start-time delta (predicted - reference).
+
+ Args:
+ predicted: List of dicts with keys: word, start, end.
+ reference: List of dicts with keys: word, start, end.
+
+ Returns:
+ Dict with keys: mae_start, max_delta_start, median_delta_start,
+ n_matched, n_ref, n_pred. Returns None values if no matches found.
+ """
+ if not predicted or not reference:
+ return {
+ "mae_start": None,
+ "max_delta_start": None,
+ "median_delta_start": None,
+ "n_matched": 0,
+ "n_ref": len(reference),
+ "n_pred": len(predicted),
+ }
+
+ # Normalize words for matching
+ pred_norm = [normalize_text(p["word"]) for p in predicted]
+ ref_norm = [normalize_text(r["word"]) for r in reference]
+
+ # Greedy left-to-right alignment
+ deltas_start = []
+ ref_idx = 0
+ for p_idx, p_word in enumerate(pred_norm):
+ if not p_word:
+ continue
+ # Scan forward in reference to find a match (allow small skips)
+ search_limit = min(ref_idx + 3, len(ref_norm))
+ for r_idx in range(ref_idx, search_limit):
+ if ref_norm[r_idx] == p_word:
+ delta = predicted[p_idx]["start"] - reference[r_idx]["start"]
+ deltas_start.append(delta)
+ ref_idx = r_idx + 1
+ break
+
+ if not deltas_start:
+ return {
+ "mae_start": None,
+ "max_delta_start": None,
+ "median_delta_start": None,
+ "n_matched": 0,
+ "n_ref": len(reference),
+ "n_pred": len(predicted),
+ }
+
+ abs_deltas = [abs(d) for d in deltas_start]
+ sorted_abs = sorted(abs_deltas)
+ n = len(sorted_abs)
+ if n % 2 == 1:
+ median = sorted_abs[n // 2]
+ else:
+ median = (sorted_abs[n // 2 - 1] + sorted_abs[n // 2]) / 2
+
+ return {
+ "mae_start": sum(abs_deltas) / len(abs_deltas),
+ "max_delta_start": max(abs_deltas),
+ "median_delta_start": median,
+ "n_matched": len(deltas_start),
+ "n_ref": len(reference),
+ "n_pred": len(predicted),
+ }
diff --git a/whisperlivekit/metrics_collector.py b/whisperlivekit/metrics_collector.py
new file mode 100644
index 0000000..03db5dc
--- /dev/null
+++ b/whisperlivekit/metrics_collector.py
@@ -0,0 +1,84 @@
+"""Lightweight runtime metrics for AudioProcessor sessions.
+
+Zero external dependencies. Negligible overhead when not queried —
+just integer increments and list appends during normal operation.
+"""
+
+import logging
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SessionMetrics:
+ """Per-session metrics collected by AudioProcessor."""
+
+ session_start: float = 0.0
+ total_audio_duration_s: float = 0.0
+ total_processing_time_s: float = 0.0
+
+ # Chunk / call counters
+ n_chunks_received: int = 0
+ n_transcription_calls: int = 0
+ n_tokens_produced: int = 0
+ n_responses_sent: int = 0
+
+ # Per-call ASR latency (seconds)
+ transcription_durations: List[float] = field(default_factory=list)
+
+ # Silence
+ n_silence_events: int = 0
+ total_silence_duration_s: float = 0.0
+
+ # --- Computed properties ---
+
+ @property
+ def rtf(self) -> float:
+ """Real-time factor: processing_time / audio_duration."""
+ if self.total_audio_duration_s <= 0:
+ return 0.0
+ return self.total_processing_time_s / self.total_audio_duration_s
+
+ @property
+ def avg_latency_ms(self) -> float:
+ """Average per-call ASR latency in milliseconds."""
+ if not self.transcription_durations:
+ return 0.0
+ return (sum(self.transcription_durations) / len(self.transcription_durations)) * 1000
+
+ @property
+ def p95_latency_ms(self) -> float:
+ """95th percentile per-call ASR latency in milliseconds."""
+ if not self.transcription_durations:
+ return 0.0
+ sorted_d = sorted(self.transcription_durations)
+ idx = int(len(sorted_d) * 0.95)
+ idx = min(idx, len(sorted_d) - 1)
+ return sorted_d[idx] * 1000
+
+ def to_dict(self) -> Dict:
+ """Serialize to a plain dict (JSON-safe)."""
+ return {
+ "session_start": self.session_start,
+ "total_audio_duration_s": round(self.total_audio_duration_s, 3),
+ "total_processing_time_s": round(self.total_processing_time_s, 3),
+ "rtf": round(self.rtf, 3),
+ "n_chunks_received": self.n_chunks_received,
+ "n_transcription_calls": self.n_transcription_calls,
+ "n_tokens_produced": self.n_tokens_produced,
+ "n_responses_sent": self.n_responses_sent,
+ "avg_latency_ms": round(self.avg_latency_ms, 2),
+ "p95_latency_ms": round(self.p95_latency_ms, 2),
+ "n_silence_events": self.n_silence_events,
+ "total_silence_duration_s": round(self.total_silence_duration_s, 3),
+ }
+
+ def log_summary(self) -> None:
+ """Emit a structured log line summarising the session."""
+ self.total_processing_time_s = sum(self.transcription_durations)
+ d = self.to_dict()
+ d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
+ logger.info(f"SESSION_METRICS {d}")
diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py
index d89aaca..94518d7 100644
--- a/whisperlivekit/parse_args.py
+++ b/whisperlivekit/parse_args.py
@@ -147,8 +147,8 @@ def parse_args():
"--backend",
type=str,
default="auto",
- choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral"],
- help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for Voxtral streaming via HuggingFace Transformers (CUDA/CPU/MPS).",
+ choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx"],
+ help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon.",
)
parser.add_argument(
"--no-vac",
diff --git a/whisperlivekit/voxtral_hf_streaming.py b/whisperlivekit/voxtral_hf_streaming.py
index 2fee95f..89ffbd7 100644
--- a/whisperlivekit/voxtral_hf_streaming.py
+++ b/whisperlivekit/voxtral_hf_streaming.py
@@ -85,10 +85,11 @@ class VoxtralHFStreamingOnlineProcessor:
processor = asr.processor
self._first_chunk_samples = processor.num_samples_first_audio_chunk
self._chunk_samples = processor.num_samples_per_audio_chunk
- self._chunk_step = processor.num_samples_per_audio_chunk_step
- self._right_pad_samples = int(
- processor.num_right_pad_tokens * processor.raw_audio_length_per_tok
- )
+ self._chunk_step = processor.raw_audio_length_per_tok
+ n_right_pad = processor.num_right_pad_tokens
+ if callable(n_right_pad):
+ n_right_pad = n_right_pad()
+ self._right_pad_samples = int(n_right_pad * processor.raw_audio_length_per_tok)
self._seconds_per_token = processor.raw_audio_length_per_tok / self.SAMPLING_RATE
self._reset_state()
@@ -238,10 +239,16 @@ class VoxtralHFStreamingOnlineProcessor:
def run_generate():
try:
with torch.no_grad():
+ # Pass generator as input_features — the model detects GeneratorType
+ # and internally converts it to input_features_generator
+ generate_kwargs = {
+ k: v for k, v in first_inputs.items()
+ if k != "input_features"
+ }
model.generate(
- input_features_generator=input_features_gen(),
+ input_features=input_features_gen(),
streamer=streamer,
- **first_inputs,
+ **generate_kwargs,
)
except Exception as e:
logger.error(f"[voxtral-hf] generate error: {e}", exc_info=True)
@@ -271,18 +278,20 @@ class VoxtralHFStreamingOnlineProcessor:
if not self._generate_started:
return
- streamer = self._streamer
- try:
- for text_fragment in streamer:
- if text_fragment:
- with self._text_lock:
- self._accumulated_text += text_fragment
- self._n_text_tokens_received += 1
- # Check if more is immediately available (non-blocking)
- if streamer.text_queue.empty():
- break
- except StopIteration:
- pass
+ text_queue = self._streamer.text_queue
+ while True:
+ try:
+ text_fragment = text_queue.get_nowait()
+ except queue.Empty:
+ break
+ # TextIteratorStreamer uses None as end-of-stream sentinel
+ if text_fragment is None:
+ self._generate_finished = True
+ break
+ if text_fragment:
+ with self._text_lock:
+ self._accumulated_text += text_fragment
+ self._n_text_tokens_received += 1
# ── Word extraction ──
diff --git a/whisperlivekit/voxtral_mlx/__init__.py b/whisperlivekit/voxtral_mlx/__init__.py
new file mode 100644
index 0000000..d008c70
--- /dev/null
+++ b/whisperlivekit/voxtral_mlx/__init__.py
@@ -0,0 +1,6 @@
+"""Pure-MLX Voxtral Realtime backend for WhisperLiveKit."""
+
+from .loader import load_voxtral_model
+from .model import VoxtralMLXModel
+
+__all__ = ["load_voxtral_model", "VoxtralMLXModel"]
diff --git a/whisperlivekit/voxtral_mlx/loader.py b/whisperlivekit/voxtral_mlx/loader.py
new file mode 100644
index 0000000..486bd71
--- /dev/null
+++ b/whisperlivekit/voxtral_mlx/loader.py
@@ -0,0 +1,282 @@
+"""
+Model weight loading for the MLX Voxtral Realtime backend.
+
+Supports two on-disk formats:
+ 1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load,
+ with optional quantisation metadata.
+ 2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``):
+ requires weight renaming and conv-weight transposition.
+
+The public entry point is :func:`load_voxtral_model` which returns the
+model, tokenizer, and raw config dict.
+"""
+
+import json
+import logging
+import re
+from pathlib import Path
+
+import mlx.core as mx
+import mlx.nn as nn
+from huggingface_hub import snapshot_download
+
+from .model import VoxtralMLXModel
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
+
+# ---------------------------------------------------------------------------
+# Downloading
+# ---------------------------------------------------------------------------
+
+_ALLOWED_PATTERNS = [
+ "consolidated.safetensors",
+ "model*.safetensors",
+ "model.safetensors.index.json",
+ "params.json",
+ "config.json",
+ "tekken.json",
+]
+
+
+def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path:
+ """Download model files from HuggingFace Hub and return the local path."""
+ return Path(snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS))
+
+
+# ---------------------------------------------------------------------------
+# Weight name remapping (Mistral → our naming)
+# ---------------------------------------------------------------------------
+
+_NAME_RULES: list[tuple[str, str]] = [
+ # Encoder convolutions
+ (r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"),
+ (r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"),
+ # Encoder transformer blocks
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)",
+ r"encoder.blocks.\1.self_attn.q_proj.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)",
+ r"encoder.blocks.\1.self_attn.k_proj.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)",
+ r"encoder.blocks.\1.self_attn.v_proj.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)",
+ r"encoder.blocks.\1.self_attn.out_proj.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)",
+ r"encoder.blocks.\1.pre_attn_norm.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)",
+ r"encoder.blocks.\1.ffn.gate.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)",
+ r"encoder.blocks.\1.ffn.down.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)",
+ r"encoder.blocks.\1.ffn.up.\2"),
+ (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)",
+ r"encoder.blocks.\1.pre_ffn_norm.\2"),
+ (r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"),
+ # Adapter
+ (r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"),
+ (r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"),
+ # Decoder embedding
+ (r"tok_embeddings\.weight", r"decoder.token_embedding.weight"),
+ # Decoder blocks
+ (r"layers\.(\d+)\.attention\.wq\.weight",
+ r"decoder.blocks.\1.self_attn.q_proj.weight"),
+ (r"layers\.(\d+)\.attention\.wk\.weight",
+ r"decoder.blocks.\1.self_attn.k_proj.weight"),
+ (r"layers\.(\d+)\.attention\.wv\.weight",
+ r"decoder.blocks.\1.self_attn.v_proj.weight"),
+ (r"layers\.(\d+)\.attention\.wo\.weight",
+ r"decoder.blocks.\1.self_attn.out_proj.weight"),
+ (r"layers\.(\d+)\.attention_norm\.weight",
+ r"decoder.blocks.\1.pre_attn_norm.weight"),
+ (r"layers\.(\d+)\.feed_forward\.w1\.weight",
+ r"decoder.blocks.\1.ffn.gate.weight"),
+ (r"layers\.(\d+)\.feed_forward\.w2\.weight",
+ r"decoder.blocks.\1.ffn.down.weight"),
+ (r"layers\.(\d+)\.feed_forward\.w3\.weight",
+ r"decoder.blocks.\1.ffn.up.weight"),
+ (r"layers\.(\d+)\.ffn_norm\.weight",
+ r"decoder.blocks.\1.pre_ffn_norm.weight"),
+ (r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight",
+ r"decoder.blocks.\1.adaptive_scale.proj_in.weight"),
+ (r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight",
+ r"decoder.blocks.\1.adaptive_scale.proj_out.weight"),
+ # Decoder final norm
+ (r"norm\.weight", r"decoder.final_norm.weight"),
+]
+
+_PREFIX_STRIP = re.compile(
+ r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\."
+)
+
+
+def _translate_weight_name(name: str) -> str | None:
+ name = _PREFIX_STRIP.sub("", name)
+ for pattern, replacement in _NAME_RULES:
+ result, n = re.subn(f"^{pattern}$", replacement, name)
+ if n:
+ return result
+ return None
+
+
+def _is_conv_weight(name: str) -> bool:
+ return ("conv1.weight" in name or "conv2.weight" in name) and "bias" not in name
+
+
+# ---------------------------------------------------------------------------
+# Converted-format weight remapping (voxmlx names → our names)
+# ---------------------------------------------------------------------------
+
+_CONVERTED_RULES: list[tuple[str, str]] = [
+ # Adapter
+ (r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"),
+ (r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"),
+ # Encoder transformer blocks
+ (r"encoder\.layers\.(\d+)\.attention\.(.*)", r"encoder.blocks.\1.self_attn.\2"),
+ (r"encoder\.layers\.(\d+)\.attn_norm\.(.*)", r"encoder.blocks.\1.pre_attn_norm.\2"),
+ (r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"encoder.blocks.\1.ffn.gate.\2"),
+ (r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"encoder.blocks.\1.ffn.down.\2"),
+ (r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"encoder.blocks.\1.ffn.up.\2"),
+ (r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)", r"encoder.blocks.\1.pre_ffn_norm.\2"),
+ (r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"),
+ # Decoder embedding
+ (r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"),
+ # Decoder blocks
+ (r"language_model\.layers\.(\d+)\.attention\.(.*)", r"decoder.blocks.\1.self_attn.\2"),
+ (r"language_model\.layers\.(\d+)\.attn_norm\.(.*)", r"decoder.blocks.\1.pre_attn_norm.\2"),
+ (r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"decoder.blocks.\1.ffn.gate.\2"),
+ (r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"decoder.blocks.\1.ffn.down.\2"),
+ (r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"decoder.blocks.\1.ffn.up.\2"),
+ (r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)", r"decoder.blocks.\1.pre_ffn_norm.\2"),
+ (r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)",
+ r"decoder.blocks.\1.adaptive_scale.proj_in.\2"),
+ (r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)",
+ r"decoder.blocks.\1.adaptive_scale.proj_out.\2"),
+ (r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"),
+]
+
+# Also remap o_proj → out_proj in both encoder and decoder
+_POST_RENAME = [
+ (r"\.o_proj\.", r".out_proj."),
+]
+
+
+def _remap_converted_name(name: str) -> str:
+ """Translate a converted-format weight name to our naming convention."""
+ for pattern, replacement in _CONVERTED_RULES:
+ result, n = re.subn(f"^{pattern}$", replacement, name)
+ if n:
+ name = result
+ break
+ for pattern, replacement in _POST_RENAME:
+ name = re.sub(pattern, replacement, name)
+ return name
+
+
+# ---------------------------------------------------------------------------
+# Loading strategies
+# ---------------------------------------------------------------------------
+
+def _has_converted_layout(path: Path) -> bool:
+ return (path / "config.json").exists() and not (path / "consolidated.safetensors").exists()
+
+
+def _load_converted_weights(path: Path):
+ with open(path / "config.json") as f:
+ config = json.load(f)
+
+ model = VoxtralMLXModel(config)
+
+ quant = config.get("quantization")
+ if quant is not None:
+ gs = quant["group_size"]
+ nn.quantize(
+ model,
+ group_size=gs,
+ bits=quant["bits"],
+ class_predicate=lambda _p, m: (
+ hasattr(m, "to_quantized") and m.weight.shape[-1] % gs == 0
+ ),
+ )
+
+ index_file = path / "model.safetensors.index.json"
+ if index_file.exists():
+ with open(index_file) as f:
+ shard_map = json.load(f)
+ shard_files = sorted(set(shard_map["weight_map"].values()))
+ weights = {}
+ for sf in shard_files:
+ weights.update(mx.load(str(path / sf)))
+ else:
+ weights = mx.load(str(path / "model.safetensors"))
+
+ remapped = {_remap_converted_name(k): v for k, v in weights.items()}
+ model.load_weights(list(remapped.items()))
+ mx.eval(model.parameters())
+ return model, config
+
+
+def _load_original_weights(path: Path):
+ with open(path / "params.json") as f:
+ config = json.load(f)
+
+ model = VoxtralMLXModel(config)
+
+ raw = mx.load(str(path / "consolidated.safetensors"))
+ mapped: dict[str, mx.array] = {}
+ skipped: list[str] = []
+
+ for name, tensor in raw.items():
+ if name == "output.weight":
+ continue
+ new_name = _translate_weight_name(name)
+ if new_name is None:
+ skipped.append(name)
+ continue
+ # Conv weights: PyTorch [C_out, C_in, K] → MLX [C_out, K, C_in]
+ if _is_conv_weight(new_name):
+ tensor = mx.swapaxes(tensor, 1, 2)
+ mapped[new_name] = tensor
+
+ if skipped:
+ logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5])
+
+ model.load_weights(list(mapped.items()))
+ mx.eval(model.parameters())
+ return model, config
+
+
+# ---------------------------------------------------------------------------
+# Tokenizer
+# ---------------------------------------------------------------------------
+
+def _load_tokenizer(model_dir: Path):
+ from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+ return Tekkenizer.from_file(str(model_dir / "tekken.json"))
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID):
+ """Load a Voxtral Realtime model and its tokenizer.
+
+ Args:
+ path_or_id: Local directory path **or** a HuggingFace model ID.
+
+ Returns:
+ ``(model, tokenizer, config)``
+ """
+ p = Path(path_or_id)
+ if not p.exists():
+ p = download_weights(path_or_id)
+
+ if _has_converted_layout(p):
+ model, config = _load_converted_weights(p)
+ else:
+ model, config = _load_original_weights(p)
+
+ tokenizer = _load_tokenizer(p)
+ logger.info("Voxtral MLX model loaded from %s", p)
+ return model, tokenizer, config
diff --git a/whisperlivekit/voxtral_mlx/model.py b/whisperlivekit/voxtral_mlx/model.py
new file mode 100644
index 0000000..0a637f8
--- /dev/null
+++ b/whisperlivekit/voxtral_mlx/model.py
@@ -0,0 +1,534 @@
+"""
+Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model.
+
+Architecture:
+ audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits
+ with DelayEmbedding providing time-conditioning to the decoder.
+
+The model supports both batch inference (full audio) and incremental streaming
+(one chunk at a time with cached encoder/decoder state).
+"""
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+# ---------------------------------------------------------------------------
+# KV Cache
+# ---------------------------------------------------------------------------
+
+
+class SlidingKVCache:
+ """Bounded key-value cache with rotating buffer for sliding-window attention.
+
+ Uses in-place writes for single-token autoregressive steps and
+ concatenation for multi-token prefills. Pre-allocates in blocks of
+ ``alloc_step`` entries to reduce repeated allocation.
+ """
+
+ alloc_step = 256
+
+ def __init__(self, capacity: int):
+ self.capacity = capacity
+ self.keys = None
+ self.values = None
+ self._offset = 0
+ self._write_idx = 0
+
+ @property
+ def offset(self) -> int:
+ return self._offset
+
+ # -- helpers --
+
+ def _reorder(self, buf):
+ """Return *buf* in temporal order (unwrap the circular buffer)."""
+ if self._write_idx == buf.shape[2]:
+ return buf
+ if self._write_idx < self._offset:
+ return mx.concatenate(
+ [buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
+ axis=2,
+ )
+ return buf[..., : self._write_idx, :]
+
+ def _drop_oldest(self, buf, n_drop, tail=None):
+ parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
+ if tail is not None:
+ parts.append(tail)
+ return mx.concatenate(parts, axis=2)
+
+ # -- update strategies --
+
+ def _append_concat(self, k, v):
+ """Multi-token update via concatenation (used during prefill)."""
+ if self.keys is None:
+ self.keys, self.values = k, v
+ else:
+ self.keys = self._reorder(self.keys)
+ self.values = self._reorder(self.values)
+ self._write_idx = self.keys.shape[2]
+ overflow = self._write_idx - self.capacity + 1
+ self.keys = self._drop_oldest(self.keys, overflow, k)
+ self.values = self._drop_oldest(self.values, overflow, v)
+ self._offset += k.shape[2]
+ self._write_idx = self.keys.shape[2]
+ return self.keys, self.values
+
+ def _write_inplace(self, k, v):
+ """Single-token update via in-place write (autoregressive step)."""
+ B, n_heads, S, dim_k = k.shape
+ dim_v = v.shape[3]
+ prev = self._offset
+
+ if self.keys is None or (
+ prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
+ ):
+ n_new = min(self.alloc_step, self.capacity - prev)
+ fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
+ fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
+ if self.keys is not None:
+ self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
+ self.values = mx.concatenate([self.values, fresh_v], axis=2)
+ else:
+ self.keys, self.values = fresh_k, fresh_v
+ self._write_idx = prev
+
+ overflow = self.keys.shape[2] - self.capacity
+ if overflow > 0:
+ self.keys = self._drop_oldest(self.keys, overflow)
+ self.values = self._drop_oldest(self.values, overflow)
+ self._write_idx = self.capacity
+
+ if self._write_idx == self.capacity:
+ self._write_idx = 0
+
+ self.keys[..., self._write_idx : self._write_idx + S, :] = k
+ self.values[..., self._write_idx : self._write_idx + S, :] = v
+ self._offset += S
+ self._write_idx += S
+
+ if self._offset < self.capacity:
+ return (
+ self.keys[..., : self._offset, :],
+ self.values[..., : self._offset, :],
+ )
+ return self.keys, self.values
+
+ # -- public API --
+
+ def update_and_fetch(self, k, v):
+ if k.shape[2] == 1:
+ return self._write_inplace(k, v)
+ return self._append_concat(k, v)
+
+
+# ---------------------------------------------------------------------------
+# Encoder components
+# ---------------------------------------------------------------------------
+
+
+class CausalConv(nn.Module):
+ """1-D causal convolution (left-padded so no future leakage)."""
+
+ def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
+ super().__init__()
+ self.stride = stride
+ self.kernel = kernel
+ self.left_pad = kernel - stride
+ self.weight = mx.zeros((channels_out, kernel, channels_in))
+ self.bias = mx.zeros((channels_out,))
+
+ def __call__(self, x: mx.array) -> mx.array:
+ if self.left_pad > 0:
+ x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
+ return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
+
+
+class _EncoderSelfAttention(nn.Module):
+ def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
+ super().__init__()
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.scale = head_dim**-0.5
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+ self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+ self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+ self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
+ self.rope_theta = rope_theta
+
+ def __call__(self, x, mask, cache=None):
+ B, L, _ = x.shape
+ q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+ k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+ v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+ pos = cache.offset if cache is not None else 0
+ q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
+ k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
+
+ if cache is not None:
+ k, v = cache.update_and_fetch(k, v)
+
+ out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
+ return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
+
+
+class _EncoderFFN(nn.Module):
+ """SwiGLU feed-forward for encoder layers."""
+
+ def __init__(self, dim: int, hidden: int):
+ super().__init__()
+ self.gate = nn.Linear(dim, hidden, bias=False)
+ self.up = nn.Linear(dim, hidden, bias=False)
+ self.down = nn.Linear(hidden, dim, bias=True)
+
+ def __call__(self, x):
+ return self.down(nn.silu(self.gate(x)) * self.up(x))
+
+
+class _EncoderBlock(nn.Module):
+ def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
+ super().__init__()
+ self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
+ self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
+ self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
+ self.ffn = _EncoderFFN(dim, hidden)
+
+ def __call__(self, x, mask, cache=None):
+ x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
+ x = x + self.ffn(self.pre_ffn_norm(x))
+ return x
+
+
+class StreamingEncoder(nn.Module):
+ """Causal Whisper-style encoder with two causal convolutions followed by
+ a stack of transformer blocks. Supports both full-sequence and
+ incremental (streaming) forward passes."""
+
+ def __init__(
+ self,
+ mel_channels: int = 128,
+ dim: int = 1280,
+ n_layers: int = 32,
+ n_heads: int = 32,
+ head_dim: int = 64,
+ hidden_dim: int = 5120,
+ rope_theta: float = 1e6,
+ sliding_window: int = 750,
+ ):
+ super().__init__()
+ self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
+ self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
+ self.blocks = [
+ _EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
+ for _ in range(n_layers)
+ ]
+ self.final_norm = nn.RMSNorm(dim, eps=1e-5)
+ self.sliding_window = sliding_window
+
+ # -- full-sequence --
+
+ def _apply_convs(self, mel: mx.array) -> mx.array:
+ x = mel.T[None, :, :] # [1, T, mel_channels]
+ x = nn.gelu(self.conv1(x))
+ x = nn.gelu(self.conv2(x))
+ return x
+
+ def forward(self, mel: mx.array) -> mx.array:
+ x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
+ for blk in self.blocks:
+ x = blk(x, mask="causal")
+ return self.final_norm(x)
+
+ # -- incremental (streaming) --
+
+ def forward_conv_incremental(self, x_in, tail1, tail2):
+ """Process new mel frames through the two causal convs using cached tails.
+
+ Args:
+ x_in: [1, N, mel_channels]
+ tail1: [1, pad1, mel_channels] or None (first call)
+ tail2: [1, pad2, dim] or None (first call)
+
+ Returns:
+ (out, new_tail1, new_tail2)
+ """
+ # Conv1 (kernel=3, stride=1 → left_pad=2)
+ if tail1 is not None:
+ c1_in = mx.concatenate([tail1, x_in], axis=1)
+ else:
+ c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
+ new_tail1 = x_in[:, -self.conv1.left_pad :, :]
+ c1_out = nn.gelu(
+ mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
+ )
+
+ # Conv2 (kernel=3, stride=2 → left_pad=1)
+ if tail2 is not None:
+ c2_in = mx.concatenate([tail2, c1_out], axis=1)
+ else:
+ c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
+ new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
+ c2_out = nn.gelu(
+ mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
+ )
+
+ return c2_out, new_tail1, new_tail2
+
+ def forward_transformer_incremental(self, x, cache_list):
+ """Run transformer blocks with per-layer KV caches."""
+ for i, blk in enumerate(self.blocks):
+ x = blk(x, mask="causal", cache=cache_list[i])
+ return self.final_norm(x)
+
+
+# ---------------------------------------------------------------------------
+# Decoder components
+# ---------------------------------------------------------------------------
+
+
+class _DecoderAttention(nn.Module):
+ """Grouped-query attention for the text decoder."""
+
+ def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
+ super().__init__()
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads
+ self.head_dim = head_dim
+ self.scale = head_dim**-0.5
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+ self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+ self.rope_theta = rope_theta
+
+ def __call__(self, x, mask=None, cache=None):
+ B, L, _ = x.shape
+ q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+ k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+ v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+ pos = cache.offset if cache is not None else 0
+ q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
+ k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
+
+ if cache is not None:
+ k, v = cache.update_and_fetch(k, v)
+
+ out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
+ return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
+
+
+class _DecoderFFN(nn.Module):
+ """SwiGLU feed-forward for decoder layers."""
+
+ def __init__(self, dim, hidden):
+ super().__init__()
+ self.gate = nn.Linear(dim, hidden, bias=False)
+ self.up = nn.Linear(dim, hidden, bias=False)
+ self.down = nn.Linear(hidden, dim, bias=False)
+
+ def __call__(self, x):
+ return self.down(nn.silu(self.gate(x)) * self.up(x))
+
+
+class AdaptiveScaling(nn.Module):
+ """Small MLP that produces a multiplicative scale from the delay embedding,
+ used to condition the FFN on the streaming delay."""
+
+ def __init__(self, dim, bottleneck):
+ super().__init__()
+ self.proj_in = nn.Linear(dim, bottleneck, bias=False)
+ self.proj_out = nn.Linear(bottleneck, dim, bias=False)
+
+ def __call__(self, cond):
+ return self.proj_out(nn.gelu(self.proj_in(cond)))
+
+
+class _DecoderBlock(nn.Module):
+ def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
+ super().__init__()
+ self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
+ self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
+ self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
+ self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
+ self.ffn = _DecoderFFN(dim, hidden)
+
+ def __call__(self, x, delay_cond, mask=None, cache=None):
+ x = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
+ scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond))
+ x = x + self.ffn(scaled)
+ return x
+
+
+class TextDecoder(nn.Module):
+ """Mistral-style causal language model with adaptive time-conditioning."""
+
+ def __init__(
+ self,
+ dim: int = 3072,
+ n_layers: int = 26,
+ n_heads: int = 32,
+ n_kv_heads: int = 8,
+ head_dim: int = 128,
+ hidden_dim: int = 9216,
+ vocab_size: int = 131072,
+ rope_theta: float = 1e6,
+ cond_dim: int = 32,
+ ):
+ super().__init__()
+ self.token_embedding = nn.Embedding(vocab_size, dim)
+ self.blocks = [
+ _DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
+ for _ in range(n_layers)
+ ]
+ self.final_norm = nn.RMSNorm(dim, eps=1e-5)
+
+ def embed(self, token_ids: mx.array) -> mx.array:
+ return self.token_embedding(token_ids)
+
+ def __call__(self, x, delay_cond, mask=None, cache=None):
+ delay_cond = delay_cond.astype(x.dtype)
+ for i, blk in enumerate(self.blocks):
+ blk_cache = cache[i] if cache is not None else None
+ x = blk(x, delay_cond, mask, blk_cache)
+ x = self.final_norm(x)
+ return self.token_embedding.as_linear(x)
+
+
+# ---------------------------------------------------------------------------
+# Adapter & embeddings
+# ---------------------------------------------------------------------------
+
+
+class EncoderToDecoderAdapter(nn.Module):
+ """Two-layer projection from encoder space to decoder space."""
+
+ def __init__(self, enc_dim: int, dec_dim: int):
+ super().__init__()
+ self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
+ self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)
+
+ def __call__(self, x):
+ return self.linear2(nn.gelu(self.linear1(x)))
+
+
+class DelayEmbedding(nn.Module):
+ """Sinusoidal embedding that encodes the streaming delay as a conditioning
+ vector for the decoder's adaptive scaling."""
+
+ def __init__(self, dim: int = 3072, theta: float = 10000.0):
+ super().__init__()
+ self.dim = dim
+ half = dim // 2
+ freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
+ self._freqs = freqs
+
+ def __call__(self, delay: mx.array) -> mx.array:
+ t = delay.reshape(-1, 1).astype(mx.float32)
+ angles = t * self._freqs
+ return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
+
+
+# ---------------------------------------------------------------------------
+# Top-level model
+# ---------------------------------------------------------------------------
+
+
+class VoxtralMLXModel(nn.Module):
+ """Top-level Voxtral Realtime model wiring encoder, adapter, and decoder."""
+
+ def __init__(self, config: dict):
+ super().__init__()
+
+ enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
+ audio_cfg = enc_cfg["audio_encoding_args"]
+ ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]
+
+ self.encoder = StreamingEncoder(
+ mel_channels=audio_cfg["num_mel_bins"],
+ dim=enc_cfg["dim"],
+ n_layers=enc_cfg["n_layers"],
+ n_heads=enc_cfg["n_heads"],
+ head_dim=enc_cfg["head_dim"],
+ hidden_dim=enc_cfg["hidden_dim"],
+ rope_theta=enc_cfg["rope_theta"],
+ sliding_window=enc_cfg["sliding_window"],
+ )
+
+ adapter_input_dim = enc_cfg["dim"] * ds_factor
+ decoder_dim = config["dim"]
+ cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)
+
+ self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)
+
+ self.decoder = TextDecoder(
+ dim=decoder_dim,
+ n_layers=config["n_layers"],
+ n_heads=config["n_heads"],
+ n_kv_heads=config["n_kv_heads"],
+ head_dim=config["head_dim"],
+ hidden_dim=config["hidden_dim"],
+ vocab_size=config["vocab_size"],
+ rope_theta=config["rope_theta"],
+ cond_dim=cond_bottleneck,
+ )
+
+ self.delay_embedding = DelayEmbedding(dim=decoder_dim)
+ self.ds_factor = ds_factor
+
+ # -- batch encode --
+
+ def encode(self, mel: mx.array) -> mx.array:
+ T = mel.shape[1]
+ if T % 2 != 0:
+ mel = mel[:, 1:]
+
+ h = self.encoder.forward(mel) # [1, T/2, enc_dim]
+ h = h[0]
+
+ n = h.shape[0]
+ trim = n % self.ds_factor
+ if trim:
+ h = h[trim:]
+ n = h.shape[0]
+
+ h = h.reshape(n // self.ds_factor, -1)
+ return self.adapter(h)
+
+ # -- incremental encode --
+
+ def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
+ """Incrementally encode new mel frames.
+
+ Returns:
+ (audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
+ """
+ x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)
+
+ x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)
+
+ if enc_cache is None:
+ enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]
+
+ x = self.encoder.forward_transformer_incremental(x, enc_cache)
+ x = x[0] # [N, enc_dim]
+
+ if ds_remainder is not None:
+ x = mx.concatenate([ds_remainder, x])
+
+ n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
+ if n_full == 0:
+ return None, conv_tail1, conv_tail2, enc_cache, x
+
+ leftover = x[n_full:] if x.shape[0] > n_full else None
+ x = x[:n_full].reshape(n_full // self.ds_factor, -1)
+ return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover
+
+ # -- decode --
+
+ def decode(self, embeddings, delay_cond, mask=None, cache=None):
+ return self.decoder(embeddings, delay_cond, mask, cache)
diff --git a/whisperlivekit/voxtral_mlx/spectrogram.py b/whisperlivekit/voxtral_mlx/spectrogram.py
new file mode 100644
index 0000000..0fdf463
--- /dev/null
+++ b/whisperlivekit/voxtral_mlx/spectrogram.py
@@ -0,0 +1,202 @@
+"""
+Mel spectrogram computation for Voxtral Realtime.
+
+Provides both a full-audio function and an incremental streaming variant
+that maintains overlap state between calls. The DFT is computed via
+matrix multiplication in MLX — no external FFT dependency required.
+"""
+
+import math
+
+import mlx.core as mx
+import numpy as np
+
+# Audio / mel constants matching the Voxtral Realtime model expectations.
+SAMPLE_RATE = 16_000
+WINDOW_SIZE = 400 # n_fft
+HOP = 160
+MEL_BANDS = 128
+MEL_MAX = 1.5 # global log-mel normalisation ceiling
+# Each output audio token spans: hop * conv_stride(2) * downsample_factor(4)
+SAMPLES_PER_TOKEN = HOP * 2 * 4 # = 1280 samples = 80 ms
+
+# Padding tokens used by the model prompt structure.
+LEFT_PAD_TOKENS = 32
+RIGHT_PAD_TOKENS = 17
+
+
+# ---------------------------------------------------------------------------
+# Slaney mel filterbank
+# ---------------------------------------------------------------------------
+
+def _build_slaney_filterbank(
+ sr: int = SAMPLE_RATE,
+ n_fft: int = WINDOW_SIZE,
+ n_mels: int = MEL_BANDS,
+ lo_hz: float = 0.0,
+ hi_hz: float = 8000.0,
+) -> np.ndarray:
+ """Compute a Slaney-normalised triangular mel filterbank.
+
+ Returns an array of shape ``[n_mels, n_fft//2 + 1]``.
+ """
+
+ def _hz2mel(f):
+ threshold = 1000.0
+ base_mel = 15.0
+ log_coeff = 27.0 / np.log(6.4)
+ mel = 3.0 * f / 200.0
+ if isinstance(f, np.ndarray):
+ above = f >= threshold
+ mel[above] = base_mel + np.log(f[above] / threshold) * log_coeff
+ elif f >= threshold:
+ mel = base_mel + np.log(f / threshold) * log_coeff
+ return mel
+
+ def _mel2hz(m):
+ threshold = 1000.0
+ base_mel = 15.0
+ log_coeff = np.log(6.4) / 27.0
+ hz = 200.0 * m / 3.0
+ above = m >= base_mel
+ hz[above] = threshold * np.exp(log_coeff * (m[above] - base_mel))
+ return hz
+
+ n_bins = n_fft // 2 + 1
+ fft_hz = np.linspace(0, sr / 2, n_bins)
+ mel_lo, mel_hi = _hz2mel(lo_hz), _hz2mel(hi_hz)
+ mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
+ hz_pts = _mel2hz(mel_pts)
+ diffs = np.diff(hz_pts)
+
+ slopes = np.expand_dims(hz_pts, 0) - np.expand_dims(fft_hz, 1)
+ rising = -slopes[:, :-2] / diffs[:-1]
+ falling = slopes[:, 2:] / diffs[1:]
+ fb = np.maximum(0.0, np.minimum(rising, falling))
+
+ # Slaney area normalisation
+ widths = 2.0 / (hz_pts[2 : n_mels + 2] - hz_pts[:n_mels])
+ fb *= np.expand_dims(widths, 0)
+ return fb.T.astype(np.float32)
+
+
+_CACHED_FILTERS: mx.array | None = None
+
+
+def _mel_filters() -> mx.array:
+ global _CACHED_FILTERS
+ if _CACHED_FILTERS is None:
+ _CACHED_FILTERS = mx.array(_build_slaney_filterbank())
+ return _CACHED_FILTERS
+
+
+# ---------------------------------------------------------------------------
+# DFT helpers
+# ---------------------------------------------------------------------------
+
+def _hann_window() -> mx.array:
+ return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
+
+
+def _dft_matrices():
+ """Pre-compute the real / imaginary DFT basis matrices."""
+ n_bins = WINDOW_SIZE // 2 + 1
+ k = mx.arange(n_bins, dtype=mx.float32)[:, None]
+ n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
+ phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
+ return mx.cos(phase), mx.sin(phase)
+
+
+def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
+ """Frame *audio* using the Hann window and compute power spectrogram."""
+ n_bins = WINDOW_SIZE // 2 + 1
+ n_frames = 1 + (audio.shape[0] - WINDOW_SIZE) // HOP
+ if n_frames <= 0:
+ return mx.zeros((0, n_bins))
+
+ offsets = (mx.arange(n_frames) * HOP)[:, None]
+ indices = offsets + mx.arange(WINDOW_SIZE)[None, :]
+ windowed = audio[indices] * window[None, :]
+
+ dft_re, dft_im = _dft_matrices()
+ real_part = windowed @ dft_re.T
+ imag_part = windowed @ dft_im.T
+ return real_part ** 2 + imag_part ** 2
+
+
+def _apply_mel_and_log(power: mx.array) -> mx.array:
+ """Convert a power spectrogram to log-mel and normalise."""
+ mel = power @ _mel_filters().T
+ log_mel = mx.log10(mx.maximum(mel, 1e-10))
+ log_mel = mx.maximum(log_mel, MEL_MAX - 8.0)
+ return (log_mel + 4.0) / 4.0
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+def compute_mel(audio: np.ndarray) -> mx.array:
+ """Compute log-mel spectrogram for a complete audio signal.
+
+ Args:
+ audio: 1-D float32 numpy array at ``SAMPLE_RATE``.
+
+ Returns:
+ ``[MEL_BANDS, T]`` MLX array.
+ """
+ x = mx.array(audio)
+ pad = WINDOW_SIZE // 2
+ x = mx.pad(x, [(pad, pad)])
+ window = _hann_window()
+
+ power = _stft_frames(x, window)
+ # Drop last frame to match reference STFT behaviour
+ power = power[:-1]
+ return _apply_mel_and_log(power).T
+
+
+def compute_mel_streaming(
+ chunk: np.ndarray,
+ overlap: np.ndarray | None,
+) -> tuple[mx.array, np.ndarray]:
+ """Incrementally compute log-mel for a new audio chunk.
+
+ Args:
+ chunk: New audio samples (float32 numpy).
+ overlap: The last ``WINDOW_SIZE - HOP`` = 240 samples from the
+ previous call, or *None* on the first call (uses zero-padding).
+
+ Returns:
+ ``(mel, new_overlap)`` where *mel* is ``[MEL_BANDS, N]`` and
+ *new_overlap* is the 240-sample tail for the next call.
+ """
+ tail_len = WINDOW_SIZE - HOP # 240
+
+ if overlap is not None:
+ combined = np.concatenate([overlap, chunk])
+ else:
+ combined = np.concatenate([np.zeros(WINDOW_SIZE // 2, dtype=np.float32), chunk])
+
+ new_overlap = combined[-tail_len:].copy()
+
+ x = mx.array(combined)
+ window = _hann_window()
+ power = _stft_frames(x, window)
+
+ if power.shape[0] == 0:
+ return mx.zeros((MEL_BANDS, 0)), new_overlap
+
+ return _apply_mel_and_log(power).T, new_overlap
+
+
+def pad_audio(
+ audio: np.ndarray,
+ n_left: int = LEFT_PAD_TOKENS,
+ n_right: int = RIGHT_PAD_TOKENS,
+) -> np.ndarray:
+ """Pad audio with silence for batch (non-streaming) inference."""
+ left = n_left * SAMPLES_PER_TOKEN
+ align = (SAMPLES_PER_TOKEN - (len(audio) % SAMPLES_PER_TOKEN)) % SAMPLES_PER_TOKEN
+ right = align + n_right * SAMPLES_PER_TOKEN
+ return np.pad(audio, (left, right))
diff --git a/whisperlivekit/voxtral_mlx_asr.py b/whisperlivekit/voxtral_mlx_asr.py
new file mode 100644
index 0000000..4c62f80
--- /dev/null
+++ b/whisperlivekit/voxtral_mlx_asr.py
@@ -0,0 +1,521 @@
+"""
+Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
+
+Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
+(streaming processor) that plug into WhisperLiveKit's audio processing
+pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
+
+Unlike the HuggingFace backend, this runs the full inference loop in-process
+(no background thread / queue) — MLX operations on Apple Silicon are fast
+enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
+"""
+
+import logging
+import sys
+import time
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import numpy as np
+from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+
+from whisperlivekit.timed_objects import ASRToken, Transcript
+from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
+from whisperlivekit.voxtral_mlx.model import SlidingKVCache
+from whisperlivekit.voxtral_mlx.spectrogram import (
+ SAMPLES_PER_TOKEN,
+ LEFT_PAD_TOKENS,
+ RIGHT_PAD_TOKENS,
+ compute_mel_streaming,
+)
+
+logger = logging.getLogger(__name__)
+
+# Decoder sliding-window size (matches the model's training configuration).
+_DECODER_WINDOW = 8192
+
+
+def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
+ """Build the prompt token sequence and return ``(token_ids, n_delay)``."""
+ pad_id = tokenizer.get_special_token("[STREAMING_PAD]")
+ ids = [tokenizer.bos_id] + [pad_id] * (n_left_pad + n_delay)
+ return ids, n_delay
+
+
+# ---------------------------------------------------------------------------
+# Model holder
+# ---------------------------------------------------------------------------
+
+
+class VoxtralMLXASR:
+ """Lightweight model holder — loads the MLX Voxtral model once and keeps
+ it alive for the lifetime of the server."""
+
+ sep = " "
+ SAMPLING_RATE = 16_000
+
+ def __init__(self, logfile=sys.stderr, **kwargs):
+ self.logfile = logfile
+ self.transcribe_kargs = {}
+
+ lan = kwargs.get("lan", "auto")
+ self.original_language = None if lan == "auto" else lan
+
+ model_path = kwargs.get("model_dir") or kwargs.get("model_path")
+ if not model_path:
+ model_size = kwargs.get("model_size", "")
+ if model_size and ("/" in model_size or model_size.startswith(".")):
+ model_path = model_size
+ else:
+ model_path = DEFAULT_MODEL_ID
+
+ t0 = time.time()
+ logger.info("Loading Voxtral MLX model '%s' ...", model_path)
+ self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
+ logger.info("Voxtral MLX model loaded in %.2fs", time.time() - t0)
+
+ self.backend_choice = "voxtral-mlx"
+
+ def transcribe(self, audio):
+ pass # all work happens in the online processor
+
+
+# ---------------------------------------------------------------------------
+# Online processor
+# ---------------------------------------------------------------------------
+
+
+class VoxtralMLXOnlineProcessor:
+ """Streaming processor that incrementally encodes audio and decodes text
+ using the MLX Voxtral model.
+
+ Lifecycle (called by ``AudioProcessor.transcription_processor``):
+
+ insert_audio_chunk(pcm, time) → process_iter() → get_buffer()
+ ... repeat ...
+ start_silence() / end_silence()
+ finish()
+ """
+
+ SAMPLING_RATE = 16_000
+
+ def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
+ self.asr = asr
+ self.logfile = logfile
+ self.end = 0.0
+ self.buffer: list = []
+ self.audio_buffer = np.array([], dtype=np.float32)
+
+ self._model = asr.model
+ self._tokenizer = asr.tokenizer
+
+ # Pre-compute prompt tokens and delay conditioning (constant across utterances).
+ self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
+ self._prefix_len = len(self._prompt_ids)
+
+ self._delay_cond = self._model.delay_embedding(
+ mx.array([self._n_delay], dtype=mx.float32)
+ )
+ mx.eval(self._delay_cond)
+
+ self._prompt_embeds = self._model.decoder.embed(
+ mx.array([self._prompt_ids])
+ )[0] # [prefix_len, dim]
+ mx.eval(self._prompt_embeds)
+
+ self._eos_id = self._tokenizer.eos_id
+ self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
+ # The streaming model has an inherent delay: text for audio at position P
+ # is generated at decoder position P + n_delay. Compensate timestamps.
+ self._delay_secs = self._n_delay * self._secs_per_token
+
+ self._reset_state()
+
+ # -- state management --
+
+ def _reset_state(self):
+ """Reset all incremental state for a fresh utterance."""
+ # Audio accumulation
+ self._pending = np.zeros(0, dtype=np.float32)
+ # Mel overlap
+ self._mel_overlap: np.ndarray | None = None
+ # Encoder incremental state
+ self._conv_tail1 = None
+ self._conv_tail2 = None
+ self._enc_cache = None
+ self._ds_remainder = None
+ # Audio embeddings not yet decoded
+ self._audio_embeds: mx.array | None = None
+ # Decoder state
+ self._dec_cache: list[SlidingKVCache] | None = None
+ self._last_token: mx.array | None = None
+ # Bookkeeping
+ self._samples_encoded = 0
+ self._positions_decoded = 0
+ self._prefilled = False
+ self._first_chunk = True
+ # Text state
+ self._full_text = ""
+ self._n_text_tokens = 0
+ self._n_committed_words = 0
+ self._time_offset = 0.0
+ # Per-word audio position tracking: decoder position (relative to prefix)
+ # where each word in _full_text started and ended
+ self._word_audio_starts: list[int] = [] # audio pos where word i started
+ self._word_audio_ends: list[int] = [] # audio pos where word i last produced a token
+ self._current_word_pos: Optional[int] = None # audio pos of current (incomplete) word's first token
+
+ # -- audio ingestion --
+
+ def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
+ self.end = audio_stream_end_time
+ self._pending = np.append(self._pending, audio)
+ self.audio_buffer = self._pending
+
+ # -- core processing --
+
+ def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
+ try:
+ return self._step(is_last)
+ except Exception as e:
+ logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
+ return [], self.end
+
+ def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
+ # 1. Encode any new audio
+ self._encode_pending()
+
+ if self._audio_embeds is None:
+ return [], self.end
+
+ # 2. Compute how many positions we can safely decode
+ total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
+ n_available = self._audio_embeds.shape[0]
+ n_decodable = min(n_available, total_safe - self._positions_decoded)
+
+ if n_decodable <= 0:
+ return [], self.end
+
+ # 3. Prefill if needed
+ if not self._prefilled:
+ if self._positions_decoded + n_available < self._prefix_len:
+ return [], self.end
+ self._do_prefill()
+ # Re-check after consuming prefix embeddings
+ n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
+ n_decodable = min(n_available, total_safe - self._positions_decoded)
+
+ if n_decodable <= 0 or self._audio_embeds is None:
+ return [], self.end
+
+ # 4. Decode available positions
+ hit_eos = self._decode_positions(n_decodable)
+
+ if hit_eos:
+ # Flush words, reset for next utterance
+ words = self._flush_all_words()
+ logger.debug(
+ "[voxtral-mlx] EOS hit during stream: flushed %d words, "
+ "samples_encoded=%d (%.2fs), text='%s'",
+ len(words), self._samples_encoded,
+ self._samples_encoded / self.SAMPLING_RATE,
+ self._full_text[-60:] if self._full_text else "",
+ )
+ saved_offset = self._time_offset
+ self._reset_state()
+ self._time_offset = saved_offset
+ return words, self.end
+
+ # 5. Extract committed words (all but the last, which may still grow)
+ return self._extract_committed_words(), self.end
+
+ def _encode_pending(self):
+ """Feed pending audio through the incremental encoder."""
+ available = len(self._pending)
+ if available < SAMPLES_PER_TOKEN:
+ return
+
+ if self._first_chunk:
+ # First chunk: prepend silence for left-padding
+ n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
+ left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
+ chunk = np.concatenate([left_pad, self._pending[:n_take]])
+ self._pending = self._pending[n_take:]
+ self._samples_encoded += n_take
+ self._first_chunk = False
+ else:
+ n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
+ chunk = self._pending[:n_take]
+ self._pending = self._pending[n_take:]
+ self._samples_encoded += n_take
+
+ mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
+
+ embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
+ self._model.encode_incremental(
+ mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
+ )
+ )
+
+ if embeds is not None:
+ mx.eval(embeds)
+ if self._audio_embeds is not None:
+ self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
+ else:
+ self._audio_embeds = embeds
+
+ self.audio_buffer = self._pending
+
+ def _do_prefill(self):
+ """Run the decoder prefill pass over the prompt + first audio embeddings."""
+ n_dec_layers = len(self._model.decoder.blocks)
+ self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]
+
+ prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
+ prefix_embeds = prefix_embeds[None, :, :] # [1, prefix_len, dim]
+
+ logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
+ mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])
+
+ self._last_token = self._sample(logits)
+ mx.async_eval(self._last_token)
+
+ # Remove consumed prefix embeddings
+ self._audio_embeds = self._audio_embeds[self._prefix_len :]
+ if self._audio_embeds.shape[0] == 0:
+ self._audio_embeds = None
+ self._positions_decoded = self._prefix_len
+ self._prefilled = True
+
+ def _decode_positions(self, n: int) -> bool:
+ """Autoregressively decode *n* positions. Returns True on EOS."""
+ base_pos = self._positions_decoded # absolute position before this batch
+ for i in range(n):
+ tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
+ combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
+ logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
+ next_tok = self._sample(logits)
+ mx.async_eval(next_tok)
+
+ token_id = self._last_token.item()
+ if token_id == self._eos_id:
+ # Close the current word if one is being built
+ if self._current_word_pos is not None:
+ self._word_audio_ends.append(base_pos + i - self._prefix_len)
+ self._current_word_pos = None
+ self._trim_embeds(i)
+ self._positions_decoded += i
+ return True
+
+ text = self._tokenizer.decode(
+ [token_id], special_token_policy=SpecialTokenPolicy.IGNORE
+ )
+
+ if text:
+ audio_pos = base_pos + i - self._prefix_len
+
+ # Detect word boundary: new word starts with space or is the very first text
+ if text.lstrip() != text or not self._full_text:
+ # Close previous word if exists
+ if self._current_word_pos is not None:
+ self._word_audio_ends.append(audio_pos)
+ # Start new word
+ self._word_audio_starts.append(audio_pos)
+ self._current_word_pos = audio_pos
+ elif self._current_word_pos is None:
+ # First token of first word (no leading space)
+ self._word_audio_starts.append(audio_pos)
+ self._current_word_pos = audio_pos
+
+ self._full_text += text
+ self._n_text_tokens += 1
+
+ if i > 0 and i % 256 == 0:
+ mx.clear_cache()
+
+ self._last_token = next_tok
+
+ self._positions_decoded += n
+ self._trim_embeds(n)
+ return False
+
+ def _trim_embeds(self, n_consumed: int):
+ if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
+ self._audio_embeds = self._audio_embeds[n_consumed:]
+ else:
+ self._audio_embeds = None
+
+ def _sample(self, logits: mx.array) -> mx.array:
+ return mx.argmax(logits[0, -1:], axis=-1).squeeze()
+
+ # -- word extraction --
+
+ def _audio_pos_to_time(self, pos: int) -> float:
+ """Convert an audio position (relative to prefix end) to seconds."""
+ return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)
+
+ def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
+ """Compute (start, end) time for a word using tracked word positions."""
+ starts = self._word_audio_starts
+ ends = self._word_audio_ends
+
+ if not starts:
+ return self._time_offset, self._time_offset
+
+ # Get start position for this word
+ if word_idx < len(starts):
+ t0 = self._audio_pos_to_time(starts[word_idx])
+ else:
+ # Fallback: estimate from last known position
+ last_pos = ends[-1] if ends else starts[-1]
+ t0 = self._audio_pos_to_time(last_pos + 1)
+
+ # Get end position: use the start of the next word, or the end of this word
+ if word_idx + 1 < len(starts):
+ t1 = self._audio_pos_to_time(starts[word_idx + 1])
+ elif word_idx < len(ends):
+ t1 = self._audio_pos_to_time(ends[word_idx] + 1)
+ else:
+ # Last word, still being built: use last known position + 1 token
+ last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
+ t1 = self._audio_pos_to_time(last_pos + 1)
+
+ return t0, t1
+
+ def _extract_committed_words(self) -> List[ASRToken]:
+ """Return complete words (all except the last which may still grow)."""
+ if not self._full_text:
+ return []
+ words = self._full_text.split()
+ tokens: List[ASRToken] = []
+ n_total = max(len(words), 1)
+
+ while len(words) > self._n_committed_words + 1:
+ w = words[self._n_committed_words]
+ idx = self._n_committed_words
+ t0, t1 = self._word_time_range(idx, n_total)
+ label = w if idx == 0 else " " + w
+ tokens.append(ASRToken(start=t0, end=t1, text=label))
+ self._n_committed_words += 1
+
+ return tokens
+
+ def _flush_all_words(self) -> List[ASRToken]:
+ """Flush every word including the last partial one."""
+ if not self._full_text:
+ return []
+ words = self._full_text.split()
+ tokens: List[ASRToken] = []
+ n_total = max(len(words), 1)
+
+ while self._n_committed_words < len(words):
+ w = words[self._n_committed_words]
+ idx = self._n_committed_words
+ t0, t1 = self._word_time_range(idx, n_total)
+ label = w if idx == 0 else " " + w
+ tokens.append(ASRToken(start=t0, end=t1, text=label))
+ self._n_committed_words += 1
+
+ return tokens
+
+ # -- interface methods --
+
+ def get_buffer(self) -> Transcript:
+ if not self._full_text:
+ return Transcript(start=None, end=None, text="")
+ words = self._full_text.split()
+ remaining = words[self._n_committed_words :]
+ if remaining:
+ return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
+ return Transcript(start=None, end=None, text="")
+
+ def start_silence(self) -> Tuple[List[ASRToken], float]:
+ words = self._flush_all_words()
+ logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
+ return words, self.end
+
+ def end_silence(self, silence_duration: float, offset: float):
+ self._time_offset += silence_duration
+ self.end += silence_duration
+
+ def new_speaker(self, change_speaker):
+ self.start_silence()
+
+ def warmup(self, audio, init_prompt=""):
+ pass
+
+ def finish(self) -> Tuple[List[ASRToken], float]:
+ logger.debug(
+ "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
+ "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
+ len(self._pending),
+ self._audio_embeds.shape if self._audio_embeds is not None else None,
+ self._samples_encoded,
+ self._positions_decoded,
+ self._prefilled,
+ self._full_text[-80:] if self._full_text else "",
+ )
+
+ # Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
+ remainder = len(self._pending) % SAMPLES_PER_TOKEN
+ if remainder > 0:
+ align_pad = SAMPLES_PER_TOKEN - remainder
+ else:
+ align_pad = 0
+
+ # Add alignment + right-padding silence
+ total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
+ if total_pad > 0:
+ self._pending = np.append(
+ self._pending, np.zeros(total_pad, dtype=np.float32)
+ )
+
+ # Encode remaining audio (including right-padding)
+ self._encode_pending()
+
+ logger.debug(
+ "[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
+ self._audio_embeds.shape if self._audio_embeds is not None else None,
+ len(self._pending),
+ )
+
+ hit_eos = False
+
+ # Decode everything that's left from right-padding
+ if self._audio_embeds is not None and self._prefilled:
+ hit_eos = self._decode_positions(self._audio_embeds.shape[0])
+ logger.debug(
+ "[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
+ hit_eos, self._full_text[-80:] if self._full_text else "",
+ )
+
+ # Flush last token if it wasn't EOS
+ if self._last_token is not None:
+ tid = self._last_token.item()
+ if tid != self._eos_id:
+ text = self._tokenizer.decode(
+ [tid], special_token_policy=SpecialTokenPolicy.IGNORE
+ )
+ if text:
+ last_pos = self._positions_decoded - self._prefix_len
+ # Check if this starts a new word
+ if text.lstrip() != text or not self._full_text:
+ if self._current_word_pos is not None:
+ self._word_audio_ends.append(last_pos)
+ self._word_audio_starts.append(last_pos)
+ self._current_word_pos = last_pos
+ elif self._current_word_pos is None:
+ self._word_audio_starts.append(last_pos)
+ self._current_word_pos = last_pos
+ self._full_text += text
+ self._n_text_tokens += 1
+
+ # Close the last word if still open
+ if self._current_word_pos is not None:
+ last_pos = self._positions_decoded - self._prefix_len
+ self._word_audio_ends.append(last_pos)
+ self._current_word_pos = None
+
+ words = self._flush_all_words()
+ logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
+ return words, self.end