This commit is contained in:
Quentin Fuxa
2026-03-04 18:17:23 +01:00
38 changed files with 12501 additions and 164 deletions

13
.dockerignore Normal file
View File

@@ -0,0 +1,13 @@
.git
.github
.venv
__pycache__
*.pyc
.pytest_cache
.mypy_cache
.ruff_cache
.cache
.tmp
.secrets
dist
build

61
.github/workflows/publish-docker.yml vendored Normal file
View File

@@ -0,0 +1,61 @@
name: Publish Docker Images
on:
push:
tags:
- "v*"
workflow_dispatch:
inputs:
tag:
description: "Image tag to publish (without image suffix)"
required: true
type: string
permissions:
contents: read
packages: write
jobs:
docker:
runs-on: ubuntu-latest
env:
IMAGE_TAG: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.tag || github.ref_name }}
strategy:
fail-fast: false
matrix:
include:
- image_suffix: cpu-diarization-sortformer
dockerfile: Dockerfile.cpu
extras: cpu,diarization-sortformer
- image_suffix: cu129-diarization-sortformer
dockerfile: Dockerfile
extras: cu129,diarization-sortformer
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set lowercase owner
id: owner
run: echo "value=${GITHUB_REPOSITORY_OWNER,,}" >> "${GITHUB_OUTPUT}"
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Setup Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push image
uses: docker/build-push-action@v6
with:
context: .
file: ./${{ matrix.dockerfile }}
push: true
build-args: |
EXTRAS=${{ matrix.extras }}
tags: |
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:${{ env.IMAGE_TAG }}-${{ matrix.image_suffix }}
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:latest-${{ matrix.image_suffix }}

6
.gitignore vendored
View File

@@ -119,9 +119,11 @@ run_*.sh
*.pt
# Debug & testing
test_*.py
/test_*.py
!test_backend_offline.py
launch.json
.DS_Store
test/*
/test/
!tests/
nllb-200-distilled-600M-ctranslate2/*
*.mp3

205
BENCHMARK.md Normal file
View File

@@ -0,0 +1,205 @@
# WhisperLiveKit Benchmark Report
Benchmark comparing all supported ASR backends, streaming policies, and model sizes on Apple Silicon.
All tests run through the full AudioProcessor pipeline (same code path as production WebSocket).
## Test Environment
| Property | Value |
|----------|-------|
| Hardware | Apple M4, 32 GB RAM |
| OS | macOS 25.3.0 (arm64) |
| Python | 3.13 |
| faster-whisper | 1.2.1 |
| mlx-whisper | installed (via mlx) |
| Voxtral MLX | native MLX backend |
| Voxtral (HF) | transformers-based |
| VAC (Silero VAD) | enabled unless noted |
| Chunk size | 100 ms |
| Pacing | no-realtime (as fast as possible) |
## Audio Test Files
| File | Duration | Language | Speakers | Description |
|------|----------|----------|----------|-------------|
| `00_00_07_english_1_speaker.wav` | 7.2 s | English | 1 | Short dictation with pauses |
| `00_00_16_french_1_speaker.wav` | 16.3 s | French | 1 | French speech with intentional silence gaps |
| `00_00_30_english_3_speakers.wav` | 30.0 s | English | 3 | Multi-speaker conversation |
Ground truth transcripts (`.transcript.json`) with per-word timestamps are hand-verified.
---
## Results
### English -- Short (7.2 s, 1 speaker)
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|---------|--------|-------|-----|-----|---------------|
| faster-whisper | LocalAgreement | base | 0.20x | 21.1% | 0.080 s |
| faster-whisper | SimulStreaming | base | 0.14x | 0.0% | 0.239 s |
| faster-whisper | LocalAgreement | small | 0.59x | 21.1% | 0.089 s |
| faster-whisper | SimulStreaming | small | 0.39x | 0.0% | 0.221 s |
| mlx-whisper | LocalAgreement | base | 0.05x | 21.1% | 0.080 s |
| mlx-whisper | SimulStreaming | base | 0.14x | 10.5% | 0.245 s |
| mlx-whisper | LocalAgreement | small | 0.16x | 21.1% | 0.089 s |
| mlx-whisper | SimulStreaming | small | 0.20x | 10.5% | 0.226 s |
| voxtral-mlx | voxtral | 4B | 0.32x | 0.0% | 0.254 s |
| voxtral (HF) | voxtral | 4B | 1.29x | 0.0% | 1.876 s |
### English -- Multi-speaker (30.0 s, 3 speakers)
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|---------|--------|-------|-----|-----|---------------|
| faster-whisper | LocalAgreement | base | 0.24x | 44.7% | 0.235 s |
| faster-whisper | SimulStreaming | base | 0.10x | 5.3% | 0.398 s |
| faster-whisper | LocalAgreement | small | 0.59x | 25.0% | 0.226 s |
| faster-whisper | SimulStreaming | small | 0.26x | 5.3% | 0.387 s |
| mlx-whisper | LocalAgreement | base | 0.06x | 23.7% | 0.237 s |
| mlx-whisper | SimulStreaming | base | 0.11x | 5.3% | 0.395 s |
| mlx-whisper | LocalAgreement | small | 0.13x | 25.0% | 0.226 s |
| mlx-whisper | SimulStreaming | small | 0.20x | 5.3% | 0.394 s |
| voxtral-mlx | voxtral | 4B | 0.31x | 9.2% | 0.176 s |
| voxtral (HF) | voxtral | 4B | 1.00x | 32.9% | 1.034 s |
<p align="center">
<img src="benchmark_chart.png" alt="Benchmark comparison on 30s English" width="800">
</p>
<p align="center">
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
</p>
### French (16.3 s, 1 speaker, `--language fr`)
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|---------|--------|-------|-----|-----|---------------|
| faster-whisper | LocalAgreement | base | 0.22x | 25.7% | 3.460 s |
| faster-whisper | SimulStreaming | base | 0.10x | 31.4% | 3.660 s |
| faster-whisper | LocalAgreement | small | 0.76x | 42.9% | 0.051 s |
| faster-whisper | SimulStreaming | small | 0.29x | 25.7% | 0.219 s |
| mlx-whisper | LocalAgreement | base | 0.09x | ~45%* | ~5.0 s* |
| mlx-whisper | SimulStreaming | base | 0.09x | 40.0% | 3.540 s |
| mlx-whisper | LocalAgreement | small | 0.14x | 25.7% | 0.083 s |
| mlx-whisper | SimulStreaming | small | 0.17x | 31.4% | 0.203 s |
| voxtral-mlx | voxtral | 4B | 0.18x | 37.1% | 3.422 s |
| voxtral (HF) | voxtral | 4B | 0.63x | 28.6% | 4.040 s |
\* mlx-whisper + LocalAgreement + base is unstable on this French file (WER fluctuates 34-1037% across runs due to hallucination loops). The `small` model does not have this problem.
**Timestamp note:** The base model produces very high timestamp MAE (3.4-3.7s) on this French file because it misaligns words around the silence gaps. The small model handles this much better (0.05-0.22s MAE). Voxtral also drifts on the silence gaps.
---
## Model Size Comparison (base vs small)
| | base | small | Observation |
|--|------|-------|-------------|
| **RTF** | 0.05-0.24x | 0.13-0.76x | small is 2-3x slower |
| **English WER (SS)** | 0-5.3% | 0-5.3% | No improvement: SimulStreaming already saturates on base |
| **English WER (LA)** | 21-44.7% | 21-25% | small reduces LA errors on longer audio |
| **French WER** | 25-40% | 25-43% | Mixed: depends on backend/policy combo |
| **French timestamps** | 3.4-5.0s MAE | 0.05-0.22s MAE | small is dramatically better for French timestamps |
In short: **base + SimulStreaming** gives the best speed/accuracy tradeoff for English. The small model only helps if you need LocalAgreement (for subtitle-grade timestamps) or non-English languages.
---
## Key Findings
### Speed (RTF = processing time / audio duration, lower is better)
1. **mlx-whisper + LocalAgreement + base** is the fastest combo on Apple Silicon: 0.05-0.06x RTF on English. 30 seconds of audio in under 2 seconds.
2. For **faster-whisper**, SimulStreaming is faster than LocalAgreement. For **mlx-whisper**, it is the opposite: LocalAgreement (0.05-0.06x) outperforms SimulStreaming (0.11-0.14x) on speed.
3. **voxtral-mlx** runs at 0.18-0.32x RTF -- 3-5x slower than mlx-whisper base, but well within real-time.
4. **voxtral (HF transformers)** hits 1.0-1.3x RTF. At the real-time boundary on Apple Silicon. Use the MLX variant instead.
5. The **small** model is 2-3x slower than base across all backends.
### Accuracy (WER = Word Error Rate, lower is better)
1. **SimulStreaming** gives dramatically lower WER than LocalAgreement on the whisper backends. On the 30s English file: 5.3% vs 23-44%.
2. **voxtral-mlx** hits 0% on short English and 9.2% on multi-speaker. It auto-detects language natively. Whisper also supports `--language auto`, but tends to bias towards English on short segments.
3. **LocalAgreement** tends to repeat the last sentence at end-of-stream (a known LCP artifact), inflating WER. This is visible in the 21% WER on the 7s file -- the same 4 extra words appear in every LA run.
4. On **French** with the correct `--language fr`, whisper base achieves 25-40% WER -- comparable to Voxtral's 28-37%. The small model does not consistently improve French WER.
### Timestamps (MAE = Mean Absolute Error on word start times)
1. **LocalAgreement** gives the best timestamps on English (0.08-0.09s MAE).
2. **SimulStreaming** is less precise (0.22-0.40s MAE) but good enough for most applications.
3. On French with silence gaps, **base model timestamps are unreliable** (3.4-5s MAE). The **small model fixes this** (0.05-0.22s MAE). This is the strongest argument for using `small` over `base`.
4. **voxtral-mlx** has good timestamps on English (0.18-0.25s MAE) but drifts on audio with long silence gaps (3.4s MAE on the French file).
### VAC (Voice Activity Classification) Impact
| Backend | Policy | VAC | 7s English WER | 30s English WER |
|---------|--------|-----|----------------|-----------------|
| faster-whisper | LocalAgreement | on | 21.1% | 44.7% |
| faster-whisper | LocalAgreement | off | 100.0% | 100.0% |
| voxtral-mlx | voxtral | on | 0.0% | 9.2% |
| voxtral-mlx | voxtral | off | 0.0% | 9.2% |
- **Whisper backends need VAC** to work in streaming mode. Without it the buffer logic breaks down and you get empty or garbage output.
- **Voxtral is unaffected by VAC** since it handles its own internal chunking. Identical results with or without. VAC still saves compute on silent segments.
---
## Recommendations
| Use Case | Backend | Policy | Model | Notes |
|----------|---------|--------|-------|-------|
| Fastest English (Apple Silicon) | mlx-whisper | SimulStreaming | base | 0.11x RTF, 5.3% WER |
| Fastest English (Linux/GPU) | faster-whisper | SimulStreaming | base | 0.10x RTF, 5.3% WER |
| Best accuracy, English | faster-whisper | SimulStreaming | small | 0.26x RTF, 5.3% WER, still fast |
| Multilingual / auto-detect | voxtral-mlx | voxtral | 4B | 100+ languages, 0.18-0.32x RTF |
| Best timestamps | any | LocalAgreement | small | 0.05-0.09s MAE, good for subtitles |
| Low memory / embedded | mlx-whisper | SimulStreaming | base | Smallest footprint, fastest response |
---
## Caveats
- **3 test files, ~53 seconds total.** Results give relative rankings between backends but should not be taken as definitive WER numbers. Run on your own data for production decisions.
- **RTF varies between runs** (up to +/-30%) depending on thermal state, background processes, and model caching. The numbers above are single sequential runs on a warm machine.
- **Only base and small tested.** Medium and large-v3 would likely improve WER at the cost of higher RTF. We did not test them here because they are slow on Apple Silicon without GPU.
---
## Reproducing These Benchmarks
```bash
# Install test dependencies
pip install -e ".[test]"
# Single backend test
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --model base --no-realtime
# With a specific language
python test_backend_offline.py --backend mlx-whisper --policy simulstreaming --model small --lan fr --no-realtime
# Multi-backend auto-detect benchmark
python test_backend_offline.py --benchmark --no-realtime
# Export to JSON
python test_backend_offline.py --benchmark --no-realtime --json results.json
# Test with your own audio
python test_backend_offline.py --backend voxtral-mlx --audio your_file.wav --no-realtime
```
The benchmark harness computes WER and timestamp accuracy automatically when ground truth
`.transcript.json` files exist alongside the audio files. See `audio_tests/` for the format.
---
## Help Us Benchmark on More Hardware
These results are from a single Apple M4 machine. We'd love to see numbers from other setups: Linux with CUDA GPUs, older Macs, different CPU architectures, cloud instances, etc.
If you run the benchmark on your hardware, please open an issue or PR with your results and we will add them here. The more data points we have, the better the recommendations get.
What we are especially interested in:
- **NVIDIA GPUs** (RTX 3090, 4090, A100, T4, etc.) with faster-whisper
- **Older Apple Silicon** (M1, M2, M3) with mlx-whisper and voxtral-mlx
- **Medium and large-v3 models** (we only tested base and small so far)
- **Longer audio files** or domain-specific audio (medical, legal, call center)
- **Other languages** beyond English and French

View File

@@ -1,86 +1,74 @@
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
# --- MARK: Builder Stage
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 AS builder-gpu
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \
apt-get install -y --no-install-recommends \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install UV and set up the environment
COPY --from=uvbin /uv /uvx /bin/
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
ENV UV_PYTHON_PREFERENCE=only-managed
ENV UV_PYTHON_INSTALL_DIR=/python
RUN uv python install 3.12
# Install dependencies first to leverage caching
ARG EXTRAS=cu129
COPY pyproject.toml uv.lock /app/
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
# Copy the source code and install the package only
COPY whisperlivekit /app/whisperlivekit
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-editable --no-cache "$@"
# --- MARK: Runtime Stage
FROM nvidia/cuda:12.9.1-cudnn-runtime-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app
RUN apt-get update && \
apt-get install -y --no-install-recommends \
python3 \
python3-pip \
python3-venv \
ffmpeg \
git \
build-essential \
python3-dev \
ca-certificates && \
rm -rf /var/lib/apt/lists/*
apt-get install -y --no-install-recommends \
ffmpeg &&\
rm -rf /var/lib/apt/lists/*
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Copy UV binaries
COPY --from=uvbin /uv /uvx /bin/
# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
# Copy the Python version
COPY --from=builder-gpu --chown=python:python /python /python
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
# Example: --build-arg EXTRAS="translation"
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# In-container caching for Hugging Face models by:
# A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is
# only for convenience at de/test stage.
# For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]
# or
# B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg.
# WARNING: This will copy ALL files in the pre-cache location.
# Conditionally copy a cache directory if provided
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Copy the virtual environment with all dependencies installed
COPY --from=builder-gpu /app/.venv /app/.venv
EXPOSE 8000
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_PYTHON_DOWNLOADS=0
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

View File

@@ -1,64 +1,76 @@
FROM python:3.13-slim
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
# --- MARK: Builder Stage
FROM debian:bookworm-slim AS builder-cpu
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \
apt-get install -y --no-install-recommends \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install UV and set up the environment
COPY --from=uvbin /uv /uvx /bin/
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
ENV UV_PYTHON_PREFERENCE=only-managed
ENV UV_PYTHON_INSTALL_DIR=/python
RUN uv python install 3.12
# Install dependencies first to leverage caching
ARG EXTRAS=cpu
COPY pyproject.toml uv.lock /app/
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
# Copy the source code and install the package only
COPY whisperlivekit /app/whisperlivekit
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-editable --no-cache "$@"
# --- MARK: Runtime Stage
FROM debian:bookworm-slim
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg \
git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
apt-get install -y --no-install-recommends \
ffmpeg &&\
rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# Copy UV binaries
COPY --from=uvbin /uv /uvx /bin/
COPY . .
# Copy the Python version
COPY --from=builder-cpu --chown=python:python /python /python
# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# Copy the virtual environment with all dependencies installed
COPY --from=builder-cpu /app/.venv /app/.venv
# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]
# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_PYTHON_DOWNLOADS=0
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]
CMD ["--model", "tiny"]

115
README.md
View File

@@ -25,6 +25,7 @@
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) (2025) - 4B-parameter multilingual speech model by Mistral AI
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
@@ -71,19 +72,61 @@ Go to `chrome-extension` for instructions.
#### Optional Dependencies
| Optional | `pip install` |
|-----------|-------------|
| **Windows/Linux optimizations** | `faster-whisper` |
| **Apple Silicon optimizations** | `mlx-whisper` |
| **Translation** | `nllw` |
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| OpenAI API | `openai` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
| Feature | `uv sync` | `pip install -e` |
|-----------|-------------|-------------|
| **Apple Silicon MLX Whisper backend** | `uv sync --extra mlx-whisper` | `pip install -e ".[mlx-whisper]"` |
| **Voxtral (MLX backend, Apple Silicon)** | `uv sync --extra voxtral-mlx` | `pip install -e ".[voxtral-mlx]"` |
| **CPU PyTorch stack** | `uv sync --extra cpu` | `pip install -e ".[cpu]"` |
| **CUDA 12.9 PyTorch stack** | `uv sync --extra cu129` | `pip install -e ".[cu129]"` |
| **Translation** | `uv sync --extra translation` | `pip install -e ".[translation]"` |
| **Sentence tokenizer** | `uv sync --extra sentence_tokenizer` | `pip install -e ".[sentence_tokenizer]"` |
| **Voxtral (HF backend)** | `uv sync --extra voxtral-hf` | `pip install -e ".[voxtral-hf]"` |
| **Speaker diarization (Sortformer / NeMo)** | `uv sync --extra diarization-sortformer` | `pip install -e ".[diarization-sortformer]"` |
| *[Not recommended]* Speaker diarization with Diart | `uv sync --extra diarization-diart` | `pip install -e ".[diarization-diart]"` |
See **Parameters & Configuration** below on how to use them.
Supported GPU profiles:
```bash
# Profile A: Sortformer diarization
uv sync --extra cu129 --extra diarization-sortformer
# Profile B: Voxtral HF + translation
uv sync --extra cu129 --extra voxtral-hf --extra translation
```
`voxtral-hf` and `diarization-sortformer` are intentionally incompatible extras and must be installed in separate environments.
See **Parameters & Configuration** below on how to use them.
<p align="center">
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
</p>
See **[BENCHMARK.md](BENCHMARK.md)** for the full benchmark with tables, model size comparison, and more.
We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR!
### Voxtral Backend
WhisperLiveKit supports [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602),
a 4B-parameter speech model from Mistral AI that natively handles 100+ languages with automatic
language detection. Whisper also supports auto-detection (`--language auto`), but Voxtral's per-chunk
detection is more reliable and does not bias towards English.
```bash
# Apple Silicon (native MLX, recommended)
pip install -e ".[voxtral-mlx]"
wlk --backend voxtral-mlx
# Linux/GPU (HuggingFace transformers)
pip install transformers torch
wlk --backend voxtral
```
Voxtral uses its own streaming policy and does not use LocalAgreement or SimulStreaming.
See [BENCHMARK.md](BENCHMARK.md) for performance numbers.
### Usage Examples
**Command-line Interface**: Start the transcription server with various options:
@@ -92,8 +135,11 @@ See **Parameters & Configuration** below on how to use them.
# Large model and translate from french to danish
wlk --model large-v3 --language fr --target-language da
# Diarization and server listening on */80
# Diarization and server listening on */80
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
# Voxtral multilingual (auto-detects language)
wlk --backend voxtral-mlx
```
@@ -151,7 +197,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
| `--backend` | ASR backend selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. Options: `mlx-whisper`, `faster-whisper`, `whisper`, `openai-api` (LocalAgreement only), `voxtral-mlx` (Apple Silicon), `voxtral` (HuggingFace) | `auto` |
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
@@ -248,7 +294,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk
**CPU only:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker build -f Dockerfile.cpu -t wlk --build-arg EXTRAS="cpu" .
docker run -p 8000:8000 --name wlk wlk
```
@@ -260,6 +306,18 @@ docker run -p 8000:8000 --name wlk wlk
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
**Compose (recommended for cache + token wiring):**
```bash
# GPU Sortformer profile
docker compose up --build wlk-gpu-sortformer
# GPU Voxtral profile
docker compose up --build wlk-gpu-voxtral
# CPU service
docker compose up --build wlk-cpu
```
### Memory Requirements
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
@@ -267,9 +325,34 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
#### Customization
- `--build-arg` Options:
- `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
- `EXTRAS="cu129,diarization-sortformer"` - GPU Sortformer profile extras.
- `EXTRAS="cu129,voxtral-hf,translation"` - GPU Voxtral profile extras.
- `EXTRAS="cpu,diarization-diart,translation"` - CPU profile extras.
- Hugging Face cache + token are configured in `compose.yml` using a named volume and `HF_TKN_FILE` (default: `./token`).
## 🔮 Use Cases
## Testing & Benchmarks
WhisperLiveKit includes a unit test suite and an offline benchmark harness.
```bash
# Install test dependencies
pip install -e ".[test]"
# Run unit tests (no model download required)
pytest tests/ -v
# Benchmark a single backend
python test_backend_offline.py --backend faster-whisper --no-realtime
# Benchmark all installed backends
python test_backend_offline.py --benchmark --no-realtime
# Export benchmark results as JSON
python test_backend_offline.py --benchmark --no-realtime --json results.json
```
See [BENCHMARK.md](BENCHMARK.md) for a full comparison of backends, policies, WER, speed, and
timestamp accuracy on Apple Silicon.
## Use Cases
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...

View File

@@ -0,0 +1,97 @@
[
{
"word": "This",
"start": 0.0,
"end": 0.24
},
{
"word": "is",
"start": 0.24,
"end": 0.56
},
{
"word": "a",
"start": 0.56,
"end": 0.76
},
{
"word": "transcription",
"start": 0.76,
"end": 1.32
},
{
"word": "test.",
"start": 1.32,
"end": 2.0
},
{
"word": "We",
"start": 2.4,
"end": 2.5
},
{
"word": "want",
"start": 2.5,
"end": 2.66
},
{
"word": "to",
"start": 2.66,
"end": 2.84
},
{
"word": "see",
"start": 2.84,
"end": 3.1
},
{
"word": "if",
"start": 3.1,
"end": 3.34
},
{
"word": "we",
"start": 3.34,
"end": 3.5
},
{
"word": "can",
"start": 3.5,
"end": 3.68
},
{
"word": "use",
"start": 3.68,
"end": 4.04
},
{
"word": "smaller",
"start": 4.04,
"end": 4.76
},
{
"word": "chunks.",
"start": 4.76,
"end": 5.16
},
{
"word": "What",
"start": 6.06,
"end": 6.32
},
{
"word": "do",
"start": 6.32,
"end": 6.44
},
{
"word": "you",
"start": 6.44,
"end": 6.58
},
{
"word": "think?",
"start": 6.58,
"end": 6.84
}
]

View File

@@ -0,0 +1,177 @@
[
{
"word": "Ok,",
"start": 2.02,
"end": 2.38
},
{
"word": "là",
"start": 2.52,
"end": 2.58
},
{
"word": "c",
"start": 2.58,
"end": 2.74
},
{
"word": "'est",
"start": 2.74,
"end": 2.76
},
{
"word": "un",
"start": 2.76,
"end": 2.86
},
{
"word": "test,",
"start": 2.86,
"end": 3.2
},
{
"word": "on",
"start": 3.34,
"end": 3.34
},
{
"word": "veut",
"start": 3.34,
"end": 3.48
},
{
"word": "voir",
"start": 3.48,
"end": 3.86
},
{
"word": "si",
"start": 3.86,
"end": 4.14
},
{
"word": "ça",
"start": 4.14,
"end": 4.26
},
{
"word": "arrive",
"start": 4.26,
"end": 4.36
},
{
"word": "à",
"start": 4.36,
"end": 4.5
},
{
"word": "capté",
"start": 4.5,
"end": 4.78
},
{
"word": "le",
"start": 4.78,
"end": 4.9
},
{
"word": "silence.",
"start": 4.9,
"end": 5.44
},
{
"word": "Là",
"start": 9.24,
"end": 9.6
},
{
"word": "il",
"start": 9.6,
"end": 9.78
},
{
"word": "est",
"start": 9.78,
"end": 9.84
},
{
"word": "une",
"start": 9.84,
"end": 9.96
},
{
"word": "telle",
"start": 9.96,
"end": 10.12
},
{
"word": "seconde",
"start": 10.12,
"end": 10.38
},
{
"word": "de",
"start": 10.38,
"end": 10.48
},
{
"word": "silence",
"start": 10.48,
"end": 10.78
},
{
"word": "et",
"start": 10.78,
"end": 11.06
},
{
"word": "je",
"start": 11.06,
"end": 11.16
},
{
"word": "vous",
"start": 11.16,
"end": 11.32
},
{
"word": "parle.",
"start": 11.32,
"end": 11.68
},
{
"word": "Et",
"start": 13.28,
"end": 13.64
},
{
"word": "voilà,",
"start": 13.64,
"end": 13.96
},
{
"word": "allez",
"start": 14.36,
"end": 14.62
},
{
"word": "on",
"start": 14.62,
"end": 14.78
},
{
"word": "va",
"start": 14.78,
"end": 14.88
},
{
"word": "tester",
"start": 14.88,
"end": 15.06
},
{
"word": "ça.",
"start": 15.06,
"end": 15.36
}
]

View File

@@ -0,0 +1,382 @@
[
{
"word": "Transcription",
"start": 0.0,
"end": 0.6
},
{
"word": "technology",
"start": 0.6,
"end": 1.24
},
{
"word": "has",
"start": 1.24,
"end": 1.5
},
{
"word": "improved",
"start": 1.5,
"end": 1.96
},
{
"word": "so",
"start": 1.96,
"end": 2.32
},
{
"word": "much",
"start": 2.32,
"end": 2.68
},
{
"word": "in",
"start": 2.68,
"end": 2.94
},
{
"word": "the",
"start": 2.94,
"end": 3.02
},
{
"word": "past",
"start": 3.02,
"end": 3.24
},
{
"word": "few",
"start": 3.24,
"end": 3.5
},
{
"word": "years.",
"start": 3.5,
"end": 3.96
},
{
"word": "Have",
"start": 4.56,
"end": 4.74
},
{
"word": "you",
"start": 4.74,
"end": 4.9
},
{
"word": "noticed",
"start": 4.9,
"end": 5.26
},
{
"word": "how",
"start": 5.26,
"end": 5.52
},
{
"word": "accurate",
"start": 5.52,
"end": 6.08
},
{
"word": "real",
"start": 6.08,
"end": 6.42
},
{
"word": "-time",
"start": 6.42,
"end": 6.74
},
{
"word": "speech",
"start": 6.74,
"end": 7.24
},
{
"word": "to",
"start": 7.24,
"end": 7.46
},
{
"word": "text",
"start": 7.46,
"end": 7.78
},
{
"word": "is",
"start": 7.78,
"end": 8.0
},
{
"word": "now?",
"start": 8.0,
"end": 8.3
},
{
"word": "Absolutely.",
"start": 8.7,
"end": 9.16
},
{
"word": "I",
"start": 10.04,
"end": 10.38
},
{
"word": "use",
"start": 10.38,
"end": 10.56
},
{
"word": "it",
"start": 10.56,
"end": 10.76
},
{
"word": "all",
"start": 10.76,
"end": 10.9
},
{
"word": "the",
"start": 10.9,
"end": 11.04
},
{
"word": "time",
"start": 11.04,
"end": 11.32
},
{
"word": "for",
"start": 11.32,
"end": 11.54
},
{
"word": "taking",
"start": 11.54,
"end": 11.86
},
{
"word": "notes",
"start": 11.86,
"end": 12.16
},
{
"word": "during",
"start": 12.16,
"end": 12.54
},
{
"word": "meetings.",
"start": 12.54,
"end": 12.94
},
{
"word": "It's",
"start": 13.6,
"end": 13.8
},
{
"word": "amazing",
"start": 13.8,
"end": 14.1
},
{
"word": "how",
"start": 14.1,
"end": 14.48
},
{
"word": "it",
"start": 14.48,
"end": 14.62
},
{
"word": "can",
"start": 14.62,
"end": 14.74
},
{
"word": "recognise",
"start": 14.74,
"end": 15.24
},
{
"word": "different",
"start": 15.24,
"end": 15.68
},
{
"word": "speakers",
"start": 15.68,
"end": 16.16
},
{
"word": "and",
"start": 16.16,
"end": 16.8
},
{
"word": "even",
"start": 16.8,
"end": 17.1
},
{
"word": "add",
"start": 17.1,
"end": 17.44
},
{
"word": "punctuation.",
"start": 17.44,
"end": 18.36
},
{
"word": "Yeah,",
"start": 18.88,
"end": 19.16
},
{
"word": "but",
"start": 19.36,
"end": 19.52
},
{
"word": "sometimes",
"start": 19.52,
"end": 20.16
},
{
"word": "noise",
"start": 20.16,
"end": 20.54
},
{
"word": "can",
"start": 20.54,
"end": 20.8
},
{
"word": "still",
"start": 20.8,
"end": 21.1
},
{
"word": "cause",
"start": 21.1,
"end": 21.44
},
{
"word": "mistakes.",
"start": 21.44,
"end": 21.94
},
{
"word": "Does",
"start": 22.68,
"end": 22.9
},
{
"word": "this",
"start": 22.9,
"end": 23.12
},
{
"word": "system",
"start": 23.12,
"end": 23.46
},
{
"word": "handle",
"start": 23.46,
"end": 23.88
},
{
"word": "that",
"start": 23.88,
"end": 24.12
},
{
"word": "well?",
"start": 24.12,
"end": 24.42
},
{
"word": "It",
"start": 24.42,
"end": 25.32
},
{
"word": "does",
"start": 25.32,
"end": 25.48
},
{
"word": "a",
"start": 25.48,
"end": 25.62
},
{
"word": "pretty",
"start": 25.62,
"end": 25.88
},
{
"word": "good",
"start": 25.88,
"end": 26.08
},
{
"word": "job",
"start": 26.08,
"end": 26.32
},
{
"word": "filtering",
"start": 26.32,
"end": 26.8
},
{
"word": "noise,",
"start": 26.8,
"end": 27.18
},
{
"word": "especially",
"start": 27.36,
"end": 28.0
},
{
"word": "with",
"start": 28.0,
"end": 28.28
},
{
"word": "models",
"start": 28.28,
"end": 28.62
},
{
"word": "that",
"start": 28.62,
"end": 28.94
},
{
"word": "use",
"start": 28.94,
"end": 29.22
},
{
"word": "voice",
"start": 29.22,
"end": 29.54
},
{
"word": "active.",
"start": 29.54,
"end": 29.9
}
]

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""Generate word-level timestamped transcripts using faster-whisper (offline).
Produces one JSON file per audio with: [{word, start, end}, ...]
"""
import json
import os
from faster_whisper import WhisperModel
AUDIO_DIR = os.path.dirname(os.path.abspath(__file__))
FILES = [
("00_00_07_english_1_speaker.wav", "en"),
("00_00_16_french_1_speaker.wav", "fr"),
("00_00_30_english_3_speakers.wav", "en"),
]
def main():
print("Loading faster-whisper model (base, cpu, float32)...")
model = WhisperModel("base", device="cpu", compute_type="float32")
for filename, lang in FILES:
audio_path = os.path.join(AUDIO_DIR, filename)
out_path = os.path.join(
AUDIO_DIR, filename.rsplit(".", 1)[0] + ".transcript.json"
)
print(f"\n{'='*60}")
print(f"Transcribing: {filename} (language={lang})")
print(f"{'='*60}")
segments, info = model.transcribe(
audio_path, word_timestamps=True, language=lang
)
words = []
for segment in segments:
if segment.words:
for w in segment.words:
words.append({
"word": w.word.strip(),
"start": round(w.start, 3),
"end": round(w.end, 3),
})
print(f" {w.start:6.2f} - {w.end:6.2f} {w.word.strip()}")
with open(out_path, "w", encoding="utf-8") as f:
json.dump(words, f, indent=2, ensure_ascii=False)
print(f"\n -> {len(words)} words written to {os.path.basename(out_path)}")
print("\nDone.")
if __name__ == "__main__":
main()

BIN
benchmark_chart.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

BIN
benchmark_scatter.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

52
compose.yml Normal file
View File

@@ -0,0 +1,52 @@
services:
wlk-gpu-sortformer:
build:
context: .
dockerfile: Dockerfile
args:
EXTRAS: ${GPU_SORTFORMER_EXTRAS:-cu129,diarization-sortformer}
image: wlk:gpu-sortformer
gpus: all
ports:
- "8000:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
command: ["--model", "medium", "--diarization", "--pcm-input"]
wlk-gpu-voxtral:
build:
context: .
dockerfile: Dockerfile
args:
EXTRAS: ${GPU_VOXTRAL_EXTRAS:-cu129,voxtral-hf,translation}
image: wlk:gpu-voxtral
gpus: all
ports:
- "8001:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
command: ["--backend", "voxtral", "--pcm-input"]
wlk-cpu:
build:
context: .
dockerfile: Dockerfile.cpu
args:
EXTRAS: ${CPU_EXTRAS:-cpu,diarization-diart,translation}
image: wlk:cpu
ports:
- "8000:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
volumes:
hf-cache:

View File

@@ -4,27 +4,21 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.18"
version = "0.2.19"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
]
authors = [{ name = "Quentin Fuxa" }]
license = { file = "LICENSE" }
requires-python = ">=3.9"
requires-python = ">=3.11, <3.14"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
"Topic :: Multimedia :: Sound/Audio :: Speech",
]
dependencies = [
"fastapi",
@@ -32,19 +26,91 @@ dependencies = [
"soundfile",
"uvicorn",
"websockets",
"torchaudio>=2.0.0",
"torch>=2.0.0",
"huggingface-hub>=0.25.0",
"faster-whisper>=1.2.0",
"torch>=2.0.0",
"torchaudio>=2.0.0",
"tqdm",
"tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]
[project.optional-dependencies]
test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
voxtral-hf = ["transformers>=5.2.0", "mistral-common[audio]"]
mlx-whisper = [
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
]
voxtral-mlx = [
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
"mistral-common[audio]",
]
voxtral-hf = [
"transformers>=5.2.0; python_version >= '3.10'",
"mistral-common[audio]",
"accelerate>=0.12",
]
cpu = ["torch>=2.0.0", "torchaudio>=2.0.0"]
cu129 = [
"torch>=2.0.0",
"torchaudio>=2.0.0",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")',
]
diarization-sortformer = [
"nemo-toolkit[asr]>2.4; python_version >= '3.10' and python_version < '3.13'",
]
diarization-diart = [
"diart",
"torch<2.9.0",
"torchaudio<2.9.0",
"torchvision<0.24.0",
]
[dependency-groups]
dev = ["rich>=14.3.3"]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "cu129" },
],
[
{ extra = "diarization-diart" },
{ extra = "cu129" },
],
[
{ extra = "voxtral-hf" },
{ extra = "diarization-sortformer" },
],
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
]
torchaudio = [
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu129"
url = "https://download.pytorch.org/whl/cu129"
explicit = true
[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
@@ -64,7 +130,8 @@ packages = [
"whisperlivekit.whisper.normalizers",
"whisperlivekit.web",
"whisperlivekit.local_agreement",
"whisperlivekit.silero_vad_models"
"whisperlivekit.voxtral_mlx",
"whisperlivekit.silero_vad_models",
]
[tool.setuptools.package-data]

291
run_benchmark.py Normal file
View File

@@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
Comprehensive benchmark runner for WhisperLiveKit.
Tests all available backend+policy combinations across multiple audio files,
model sizes, and VAC on/off configurations. Outputs structured JSON that
is consumed by the report generator.
Usage:
python run_benchmark.py # full benchmark
python run_benchmark.py --quick # subset (tiny models, fewer combos)
python run_benchmark.py --json results.json # custom output path
"""
import argparse
import asyncio
import gc
import json
import logging
import platform
import subprocess
import sys
import time
from dataclasses import asdict
from pathlib import Path
logging.basicConfig(level=logging.WARNING, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger("benchmark")
logger.setLevel(logging.INFO)
# Re-use harness functions
sys.path.insert(0, str(Path(__file__).parent))
from test_backend_offline import (
AUDIO_TESTS_DIR,
SAMPLE_RATE,
TestResult,
create_engine,
discover_audio_files,
download_sample_audio,
load_audio,
run_test,
)
CACHE_DIR = Path(__file__).parent / ".test_cache"
def get_system_info() -> dict:
"""Collect system metadata for the report."""
info = {
"platform": platform.platform(),
"machine": platform.machine(),
"processor": platform.processor(),
"python_version": platform.python_version(),
}
# macOS: get chip info
try:
chip = subprocess.check_output(
["sysctl", "-n", "machdep.cpu.brand_string"], text=True
).strip()
info["cpu"] = chip
except Exception:
info["cpu"] = platform.processor()
# RAM
try:
mem_bytes = int(
subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip()
)
info["ram_gb"] = round(mem_bytes / (1024**3))
except Exception:
info["ram_gb"] = None
# Backend versions
versions = {}
try:
import faster_whisper
versions["faster-whisper"] = faster_whisper.__version__
except ImportError:
pass
try:
import mlx_whisper # noqa: F401
versions["mlx-whisper"] = "installed"
except ImportError:
pass
try:
import mlx.core as mx
versions["mlx"] = mx.__version__
except ImportError:
pass
try:
import transformers
versions["transformers"] = transformers.__version__
except ImportError:
pass
try:
import torch
versions["torch"] = torch.__version__
except ImportError:
pass
info["backend_versions"] = versions
return info
def detect_combos(quick: bool = False) -> list:
"""Build list of (backend, policy, model_size) combos to test."""
combos = []
# Model sizes to test
model_sizes = ["tiny", "base", "small"] if not quick else ["tiny", "base"]
# faster-whisper
try:
import faster_whisper # noqa: F401
for model in model_sizes:
combos.append({"backend": "faster-whisper", "policy": "localagreement", "model": model})
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "model": model})
except ImportError:
pass
# mlx-whisper
try:
import mlx_whisper # noqa: F401
for model in model_sizes:
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "model": model})
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "model": model})
except ImportError:
pass
# voxtral-mlx (single model, single policy)
try:
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "model": ""})
except ImportError:
pass
# voxtral HF (single model, single policy)
try:
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
combos.append({"backend": "voxtral", "policy": "voxtral", "model": ""})
except ImportError:
pass
return combos
def collect_audio_files() -> list:
"""Collect all benchmark audio files."""
files = []
# audio_tests/ directory
if AUDIO_TESTS_DIR.is_dir():
files.extend(discover_audio_files(str(AUDIO_TESTS_DIR)))
# JFK sample
jfk = CACHE_DIR / "jfk.wav"
if not jfk.exists():
jfk = download_sample_audio()
if jfk.exists():
files.append(jfk)
return files
async def run_single_combo(
combo: dict, audio_files: list, vac: bool, lan: str, max_duration: float,
) -> list:
"""Run one backend+policy+model combo across all audio files."""
backend = combo["backend"]
policy = combo["policy"]
model = combo["model"]
results = []
try:
engine = create_engine(
backend=backend,
model_size=model,
lan=lan,
vac=vac,
policy=policy,
)
# Quiet noisy loggers
for mod in (
"whisperlivekit.audio_processor",
"whisperlivekit.simul_whisper",
"whisperlivekit.tokens_alignment",
"whisperlivekit.simul_whisper.align_att_base",
"whisperlivekit.simul_whisper.simul_whisper",
):
logging.getLogger(mod).setLevel(logging.WARNING)
for audio_path in audio_files:
duration = len(load_audio(str(audio_path))) / SAMPLE_RATE
if duration > max_duration:
logger.info(f" Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s)")
continue
file_lan = lan
if "french" in audio_path.name.lower() and lan == "en":
file_lan = "fr"
audio = load_audio(str(audio_path))
result = await run_test(
engine, audio, chunk_ms=100, realtime=False,
audio_file=audio_path.name, backend=backend,
policy=policy, lan=file_lan,
)
# Tag with extra metadata
result_dict = asdict(result)
result_dict["model_size"] = model
result_dict["vac"] = vac
results.append(result_dict)
except Exception as e:
logger.error(f" FAILED: {e}")
import traceback
traceback.print_exc()
return results
async def run_full_benchmark(combos, audio_files, max_duration=60.0):
"""Run all combos with VAC on and off."""
all_results = []
total = len(combos) * 2 # x2 for VAC on/off
idx = 0
for combo in combos:
for vac in [True, False]:
idx += 1
vac_str = "VAC=on" if vac else "VAC=off"
desc = f"{combo['backend']} / {combo['policy']}"
if combo["model"]:
desc += f" / {combo['model']}"
desc += f" / {vac_str}"
print(f"\n{'='*70}")
print(f"[{idx}/{total}] {desc}")
print(f"{'='*70}")
results = await run_single_combo(
combo, audio_files, vac=vac, lan="en", max_duration=max_duration,
)
all_results.extend(results)
# Free memory between combos
gc.collect()
return all_results
def main():
parser = argparse.ArgumentParser(description="Run comprehensive WhisperLiveKit benchmark")
parser.add_argument("--quick", action="store_true", help="Quick mode: fewer models and combos")
parser.add_argument("--json", default="benchmark_results.json", dest="json_output", help="Output JSON path")
parser.add_argument("--max-duration", type=float, default=60.0, help="Max audio duration in seconds")
args = parser.parse_args()
system_info = get_system_info()
combos = detect_combos(quick=args.quick)
audio_files = collect_audio_files()
print(f"System: {system_info.get('cpu', 'unknown')}, {system_info.get('ram_gb', '?')}GB RAM")
print(f"Backends: {list(system_info['backend_versions'].keys())}")
print(f"Combos to test: {len(combos)} x 2 (VAC on/off) = {len(combos)*2}")
print(f"Audio files: {[f.name for f in audio_files]}")
print()
t0 = time.time()
all_results = asyncio.run(
run_full_benchmark(combos, audio_files, max_duration=args.max_duration)
)
total_time = time.time() - t0
output = {
"system_info": system_info,
"benchmark_date": time.strftime("%Y-%m-%d %H:%M"),
"total_benchmark_time_s": round(total_time, 1),
"n_combos": len(combos) * 2,
"n_audio_files": len(audio_files),
"results": all_results,
}
Path(args.json_output).write_text(json.dumps(output, indent=2, ensure_ascii=False))
print(f"\nBenchmark complete in {total_time:.0f}s. Results: {args.json_output}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,580 @@
#!/usr/bin/env python3
"""Offline Python support matrix runner for WhisperLiveKit."""
from __future__ import annotations
import argparse
import os
import shlex
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
try:
from rich.console import Console
from rich.table import Table
HAS_RICH = True
except Exception:
HAS_RICH = False
SAMPLE_URL = (
"https://github.com/pyannote/pyannote-audio/raw/develop/tutorials/assets/sample.wav"
)
SAMPLE_PATH = Path("audio_tests/support-matrix-sample.wav")
DEFAULT_LOGS_DIR = Path("outputs/python-matrix/logs")
PYTHON_VERSIONS = ("3.11", "3.12", "3.13")
CONSOLE = Console() if HAS_RICH else None
@dataclass(frozen=True)
class MatrixRow:
row_id: str
extras: tuple[str, ...]
backend: str
policy: str
diarization_backend: str
requires_gpu: bool = False
CASES = (
MatrixRow(
row_id="fw-diart-cpu",
extras=("test", "cpu", "diarization-diart"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="diart",
),
MatrixRow(
row_id="fw-sortformer-cpu",
extras=("test", "cpu", "diarization-sortformer"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="sortformer",
),
MatrixRow(
row_id="fw-sortformer-gpu",
extras=("test", "cu129", "diarization-sortformer"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="sortformer",
requires_gpu=True,
),
MatrixRow(
row_id="voxtral-diart-cpu",
extras=("test", "cpu", "voxtral-hf", "diarization-diart"),
backend="voxtral",
policy="voxtral",
diarization_backend="diart",
),
)
EXPECTED_FAILURE_CASES = {
("3.11", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
("3.12", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
}
UNSUPPORTED_CASES = {
("3.13", "fw-sortformer-cpu"): "unsupported_py313_sortformer_protobuf",
("3.13", "fw-sortformer-gpu"): "unsupported_py313_sortformer_protobuf",
}
@dataclass(frozen=True)
class CaseResult:
python_version: str
row_id: str
status: Literal["PASS", "FAIL", "N/A"]
reason: str
duration_sec: float
hint: str = ""
log_path: str = ""
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Minimal WhisperLiveKit offline support matrix"
)
parser.add_argument(
"--timeout-sec",
type=int,
default=300,
help="Per-case timeout in seconds (default: 300)",
)
parser.add_argument(
"--logs-dir",
default=str(DEFAULT_LOGS_DIR),
help="Directory where per-case logs are written (default: outputs/python-matrix/logs)",
)
return parser.parse_args()
def safe_slug(text: str) -> str:
return text.replace("=", "-").replace("|", "__").replace("/", "-").replace(" ", "-")
def status_style(status: str) -> str:
if status == "PASS":
return "green"
if status == "FAIL":
return "bold red"
if status == "N/A":
return "yellow"
return "white"
def print_line(message: str, style: str | None = None) -> None:
if CONSOLE is None:
print(message)
return
if style:
CONSOLE.print(message, style=style, highlight=False)
else:
CONSOLE.print(message, highlight=False)
def tail_text(text: str | None, max_chars: int = 220) -> str:
if not text:
return ""
normalized = " ".join(text.split())
if len(normalized) <= max_chars:
return normalized
return normalized[-max_chars:]
def run_command(
cmd: list[str],
cwd: Path,
env: dict[str, str],
timeout: int | None = None,
log_path: Path | None = None,
log_section: str | None = None,
) -> subprocess.CompletedProcess[str]:
def _append_log(
*,
command: list[str],
section: str,
returncode: int | None,
stdout: str | None,
stderr: str | None,
timed_out: bool = False,
) -> None:
if log_path is None:
return
log_path.parent.mkdir(parents=True, exist_ok=True)
with log_path.open("a", encoding="utf-8") as f:
f.write(f"\n=== {section} ===\n")
f.write(f"$ {shlex.join(command)}\n")
if timed_out:
f.write("status: timeout\n")
else:
f.write(f"status: exit_code={returncode}\n")
if stdout:
f.write("--- stdout ---\n")
f.write(stdout)
if not stdout.endswith("\n"):
f.write("\n")
if stderr:
f.write("--- stderr ---\n")
f.write(stderr)
if not stderr.endswith("\n"):
f.write("\n")
section = log_section or "command"
try:
proc = subprocess.run(
cmd,
cwd=str(cwd),
env=env,
text=True,
capture_output=True,
check=False,
timeout=timeout,
)
except subprocess.TimeoutExpired as exc:
_append_log(
command=cmd,
section=section,
returncode=None,
stdout=exc.stdout if isinstance(exc.stdout, str) else None,
stderr=exc.stderr if isinstance(exc.stderr, str) else None,
timed_out=True,
)
raise
_append_log(
command=cmd,
section=section,
returncode=proc.returncode,
stdout=proc.stdout,
stderr=proc.stderr,
)
return proc
def detect_gpu_available() -> bool:
try:
proc = subprocess.run(
["nvidia-smi", "-L"],
text=True,
capture_output=True,
check=False,
timeout=10,
)
except (FileNotFoundError, subprocess.TimeoutExpired):
return False
return proc.returncode == 0
def download_sample(repo_root: Path) -> Path:
target = repo_root / SAMPLE_PATH
target.parent.mkdir(parents=True, exist_ok=True)
cmd = [
"curl",
"--fail",
"--location",
"--silent",
"--show-error",
SAMPLE_URL,
"--output",
str(target),
]
proc = run_command(cmd, cwd=repo_root, env=os.environ.copy())
if proc.returncode != 0:
hint = tail_text(proc.stderr or proc.stdout)
raise RuntimeError(f"sample_download_failed: {hint}")
return target
def sync_case_environment(
repo_root: Path,
python_version: str,
row: MatrixRow,
env_dir: Path,
log_path: Path,
) -> tuple[bool, str]:
cmd = ["uv", "sync", "--python", python_version, "--no-dev"]
for extra in row.extras:
cmd.extend(["--extra", extra])
env = os.environ.copy()
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
proc = run_command(
cmd,
cwd=repo_root,
env=env,
log_path=log_path,
log_section="sync",
)
if proc.returncode != 0:
return False, tail_text(proc.stderr or proc.stdout)
return True, ""
def apply_expected_failure_policy(result: CaseResult) -> CaseResult:
expected_reason = EXPECTED_FAILURE_CASES.get((result.python_version, result.row_id))
if result.status != "FAIL" or not expected_reason:
return result
override_hint = result.hint
if result.reason:
override_hint = (
f"expected_failure_override original_reason={result.reason}; {override_hint}"
if override_hint
else f"expected_failure_override original_reason={result.reason}"
)
return CaseResult(
python_version=result.python_version,
row_id=result.row_id,
status="N/A",
reason=expected_reason,
duration_sec=result.duration_sec,
hint=override_hint,
log_path=result.log_path,
)
def build_offline_command(
python_version: str,
row: MatrixRow,
sample_audio: Path,
timeout_sec: int,
) -> tuple[list[str], int | None]:
base_cmd = [
"uv",
"run",
"--python",
python_version,
"--no-sync",
"python",
"test_backend_offline.py",
"--backend",
row.backend,
"--policy",
row.policy,
"--audio",
str(sample_audio),
"--model",
"tiny",
"--diarization",
"--diarization-backend",
row.diarization_backend,
"--lan",
"en",
"--no-realtime",
]
if shutil.which("timeout"):
return ["timeout", str(timeout_sec), *base_cmd], None
return base_cmd, timeout_sec
def run_case(
repo_root: Path,
python_version: str,
row: MatrixRow,
sample_audio: Path,
timeout_sec: int,
gpu_available: bool,
logs_dir: Path,
) -> CaseResult:
start = time.monotonic()
case_slug = safe_slug(f"py{python_version}-{row.row_id}")
log_path = logs_dir / f"run-{case_slug}.log"
log_path.parent.mkdir(parents=True, exist_ok=True)
log_path.write_text("", encoding="utf-8")
unsupported_reason = UNSUPPORTED_CASES.get((python_version, row.row_id))
if unsupported_reason:
log_path.write_text(
f"[matrix] precheck_short_circuit status=N/A reason={unsupported_reason}\n",
encoding="utf-8",
)
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="N/A",
reason=unsupported_reason,
duration_sec=0.0,
hint="unsupported_case_precheck",
log_path=str(log_path),
)
if row.requires_gpu and not gpu_available:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="N/A",
reason="gpu_unavailable",
duration_sec=0.0,
hint="nvidia-smi unavailable or failed",
log_path=str(log_path),
)
env_dir = repo_root / ".matrix-envs" / safe_slug(f"py{python_version}-{row.row_id}")
sync_ok, sync_hint = sync_case_environment(
repo_root,
python_version,
row,
env_dir,
log_path=log_path,
)
if not sync_ok:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason="dependency_sync_failed",
duration_sec=round(time.monotonic() - start, 3),
hint=sync_hint,
log_path=str(log_path),
)
cmd, process_timeout = build_offline_command(
python_version, row, sample_audio, timeout_sec
)
env = os.environ.copy()
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
if row.requires_gpu:
env.pop("CUDA_VISIBLE_DEVICES", None)
else:
env["CUDA_VISIBLE_DEVICES"] = ""
try:
proc = run_command(
cmd,
cwd=repo_root,
env=env,
timeout=process_timeout,
log_path=log_path,
log_section="offline",
)
except subprocess.TimeoutExpired as exc:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason="offline_timeout",
duration_sec=round(time.monotonic() - start, 3),
hint=tail_text((exc.stderr or "") if isinstance(exc.stderr, str) else ""),
log_path=str(log_path),
)
hint = tail_text(proc.stderr or proc.stdout)
if proc.returncode == 0:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="PASS",
reason="ok",
duration_sec=round(time.monotonic() - start, 3),
hint=hint,
log_path=str(log_path),
)
reason = "offline_timeout" if proc.returncode == 124 else "offline_run_failed"
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason=reason,
duration_sec=round(time.monotonic() - start, 3),
hint=hint,
log_path=str(log_path),
)
def print_summary(results: list[CaseResult]) -> None:
pass_count = sum(1 for row in results if row.status == "PASS")
fail_count = sum(1 for row in results if row.status == "FAIL")
na_count = sum(1 for row in results if row.status == "N/A")
if CONSOLE is None:
print("\n[matrix] results")
print("python | row | status | reason | duration_s")
print("---|---|---|---|---")
for result in results:
print(
f"{result.python_version} | {result.row_id} | {result.status} | "
f"{result.reason} | {result.duration_sec:.3f}"
)
print(
f"\n[matrix] summary pass={pass_count} fail={fail_count} "
f"na={na_count} total={len(results)}"
)
else:
table = Table(title="Support Matrix Results")
table.add_column("Python", style="cyan", no_wrap=True)
table.add_column("Row", style="white")
table.add_column("Status", no_wrap=True)
table.add_column("Reason")
table.add_column("Duration (s)", justify="right", no_wrap=True)
for result in results:
table.add_row(
result.python_version,
result.row_id,
f"[{status_style(result.status)}]{result.status}[/{status_style(result.status)}]",
result.reason,
f"{result.duration_sec:.3f}",
)
CONSOLE.print()
CONSOLE.print(table)
CONSOLE.print(
f"[bold]Summary[/bold] "
f"pass=[green]{pass_count}[/green] "
f"fail=[bold red]{fail_count}[/bold red] "
f"na=[yellow]{na_count}[/yellow] "
f"total={len(results)}"
)
diagnostics = [row for row in results if row.status in {"FAIL", "N/A"} and row.hint]
if diagnostics:
if CONSOLE is None:
print("\n[matrix] diagnostics (failed/n-a cases)")
for row in diagnostics:
print(
f"- py={row.python_version} row={row.row_id} "
f"status={row.status} reason={row.reason}"
)
print(f" hint: {row.hint}")
if row.log_path:
print(f" log: {row.log_path}")
else:
diagnostics_table = Table(title="Diagnostics (FAIL / N/A)")
diagnostics_table.add_column("Case", style="cyan")
diagnostics_table.add_column("Status", no_wrap=True)
diagnostics_table.add_column("Reason")
diagnostics_table.add_column("Hint")
diagnostics_table.add_column("Log")
for row in diagnostics:
diagnostics_table.add_row(
f"py={row.python_version} {row.row_id}",
f"[{status_style(row.status)}]{row.status}[/{status_style(row.status)}]",
row.reason,
row.hint,
row.log_path,
)
CONSOLE.print()
CONSOLE.print(diagnostics_table)
def main() -> int:
args = parse_args()
if args.timeout_sec <= 0:
print("[matrix] error: --timeout-sec must be > 0", file=sys.stderr)
return 1
repo_root = Path(__file__).resolve().parents[1]
logs_dir = (repo_root / args.logs_dir).resolve()
logs_dir.mkdir(parents=True, exist_ok=True)
print_line(f"[matrix] repo_root={repo_root}", style="cyan")
print_line(f"[matrix] timeout_sec={args.timeout_sec}", style="cyan")
print_line(f"[matrix] logs_dir={logs_dir}", style="cyan")
try:
sample_audio = download_sample(repo_root)
except Exception as exc: # pragma: no cover - straightforward failure path
if CONSOLE is None:
print(f"[matrix] sample_download_failed: {exc}", file=sys.stderr)
else:
CONSOLE.print(
f"[matrix] sample_download_failed: {exc}",
style="bold red",
highlight=False,
)
return 1
print_line(f"[matrix] sample_audio={sample_audio}", style="cyan")
gpu_available = detect_gpu_available()
print_line(f"[matrix] gpu_available={gpu_available}", style="cyan")
results: list[CaseResult] = []
for python_version in PYTHON_VERSIONS:
for row in CASES:
print_line(
f"\n[matrix] running py={python_version} row={row.row_id}", style="blue"
)
result = run_case(
repo_root=repo_root,
python_version=python_version,
row=row,
sample_audio=sample_audio,
timeout_sec=args.timeout_sec,
gpu_available=gpu_available,
logs_dir=logs_dir,
)
result = apply_expected_failure_policy(result)
results.append(result)
print_line(
f"[matrix] {result.status} py={result.python_version} "
f"row={result.row_id} reason={result.reason} duration={result.duration_sec:.3f}s",
style=status_style(result.status),
)
if result.log_path:
print_line(f"[matrix] log={result.log_path}", style="dim")
print_summary(results)
fail_count = sum(1 for row in results if row.status == "FAIL")
return 1 if fail_count else 0
if __name__ == "__main__":
raise SystemExit(main())

803
test_backend_offline.py Normal file
View File

@@ -0,0 +1,803 @@
#!/usr/bin/env python3
"""
Offline test harness and benchmark suite for WhisperLiveKit backends.
Simulates a client-server session by feeding audio files as PCM bytes through
the full AudioProcessor pipeline (the same path used by the WebSocket server),
without needing a browser or microphone.
Computes WER (Word Error Rate) and timestamp accuracy when ground truth
transcript files (.transcript.json) are available alongside audio files.
Usage:
# Test with a single audio file:
python test_backend_offline.py --backend faster-whisper --audio audio_tests/00_00_07_english_1_speaker.wav
# Test all files in audio_tests/:
python test_backend_offline.py --backend faster-whisper --no-realtime
# Override streaming policy:
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --no-realtime
# Multi-backend benchmark (auto-detects all installed backends):
python test_backend_offline.py --benchmark --no-realtime
# Export results as JSON:
python test_backend_offline.py --benchmark --no-realtime --json results.json
# Insert silence for testing silence handling:
python test_backend_offline.py --backend faster-whisper --insert-silence 3.0 2.0
"""
import argparse
import asyncio
import json
import logging
import sys
import time
import urllib.request
from pathlib import Path
from dataclasses import dataclass, asdict, field
from typing import List, Optional
import numpy as np
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("test_offline")
logger.setLevel(logging.INFO)
SAMPLE_RATE = 16000
JFK_WAV_URL = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
CACHE_DIR = Path(__file__).parent / ".test_cache"
AUDIO_TESTS_DIR = Path(__file__).parent / "audio_tests"
AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
@dataclass
class WordTimestamp:
"""Word with its start/end time."""
word: str
start: float
end: float
@dataclass
class TestResult:
"""Structured result from a single test run."""
audio_file: str
audio_duration_s: float
backend: str
policy: str
language: str
chunk_ms: int
realtime_pacing: bool
# Timing
processing_time_s: float
rtf: float # real-time factor
# Transcription output
transcription: str
n_lines: int
n_responses: int
# WER metrics (None if no ground truth)
wer: Optional[float] = None
wer_details: Optional[dict] = None
# Timestamp accuracy (None if no ground truth)
timestamp_mae: Optional[float] = None
timestamp_max_delta: Optional[float] = None
timestamp_median_delta: Optional[float] = None
# Word-level timestamps
word_timestamps: List[WordTimestamp] = field(default_factory=list)
# Raw last response
last_response: Optional[dict] = None
def download_sample_audio() -> Path:
"""Download the jfk.wav sample if not cached."""
CACHE_DIR.mkdir(exist_ok=True)
path = CACHE_DIR / "jfk.wav"
if not path.exists():
logger.info(f"Downloading sample audio to {path} ...")
urllib.request.urlretrieve(JFK_WAV_URL, path)
logger.info("Done.")
return path
def load_audio(path: str) -> np.ndarray:
"""Load audio file as float32 mono 16kHz numpy array.
Supports WAV, FLAC (via soundfile) and MP3, OGG, M4A (via librosa).
"""
ext = Path(path).suffix.lower()
if ext in (".mp3", ".ogg", ".m4a"):
import librosa
audio, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
return audio.astype(np.float32)
import soundfile as sf
audio, sr = sf.read(path, dtype="float32")
if audio.ndim > 1:
audio = audio.mean(axis=1)
if sr != SAMPLE_RATE:
import librosa
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
return audio
def insert_silence(audio: np.ndarray, silence_sec: float, position_sec: float) -> np.ndarray:
"""Insert silence into audio at a given position.
Args:
audio: Float32 mono audio array at SAMPLE_RATE.
silence_sec: Duration of silence to insert in seconds.
position_sec: Position in seconds where silence starts.
Returns:
New audio array with silence inserted.
"""
pos_samples = int(position_sec * SAMPLE_RATE)
silence_samples = int(silence_sec * SAMPLE_RATE)
pos_samples = min(pos_samples, len(audio))
silence = np.zeros(silence_samples, dtype=np.float32)
return np.concatenate([audio[:pos_samples], silence, audio[pos_samples:]])
def float32_to_s16le_bytes(audio: np.ndarray) -> bytes:
"""Convert float32 audio to s16le PCM bytes (what the browser sends)."""
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
def create_engine(
backend: str, model_size: str, lan: str,
diarization: bool = False,
diarization_backend: str = "",
vac: bool = True,
policy: str = "",
):
"""Create a TranscriptionEngine with the given backend config."""
import gc
from whisperlivekit.core import TranscriptionEngine
# Reset singleton so we get a fresh instance
TranscriptionEngine._instance = None
TranscriptionEngine._initialized = False
gc.collect()
kwargs = dict(
backend=backend,
lan=lan,
pcm_input=True,
vac=vac,
transcription=True,
diarization=diarization,
)
if diarization_backend:
kwargs["diarization_backend"] = diarization_backend
if model_size:
kwargs["model_size"] = model_size
if policy:
kwargs["backend_policy"] = policy
return TranscriptionEngine(**kwargs)
def _extract_text_from_response(response_dict: dict) -> str:
"""Extract full transcription text from a FrontData dict."""
def _strip_or_empty(value: object) -> str:
return value.strip() if isinstance(value, str) else ""
segments = response_dict.get("lines", [])
full_text = " ".join(
text
for seg in segments
if isinstance(seg, dict)
for text in [_strip_or_empty(seg.get("text"))]
if text
)
buf = _strip_or_empty(response_dict.get("buffer_transcription"))
if buf:
full_text = f"{full_text} {buf}".strip() if full_text else buf
return full_text
async def run_test(
engine, audio: np.ndarray, chunk_ms: int, realtime: bool,
audio_file: str = "", backend: str = "", policy: str = "", lan: str = "",
) -> TestResult:
"""
Simulate a client session through the full AudioProcessor pipeline.
1. Create AudioProcessor (one per "client session")
2. Start async pipeline (transcription_processor, results_formatter, etc.)
3. Feed audio as PCM bytes in timed chunks
4. Collect and display FrontData responses
5. Signal EOF and cleanup
"""
from whisperlivekit.audio_processor import AudioProcessor
chunk_samples = int(SAMPLE_RATE * chunk_ms / 1000)
total_samples = len(audio)
audio_duration = total_samples / SAMPLE_RATE
logger.info(
f"Audio: {audio_duration:.2f}s | "
f"Chunk: {chunk_ms}ms ({chunk_samples} samples) | "
f"Steps: {total_samples // chunk_samples + 1} | "
f"Realtime: {realtime}"
)
# --- Server side: create processor and start pipeline ---
processor = AudioProcessor(transcription_engine=engine)
results_generator = await processor.create_tasks()
# Collect results in background (like handle_websocket_results)
all_responses = []
response_count = 0
last_printed_text = ""
async def collect_results():
nonlocal response_count, last_printed_text
async for response in results_generator:
all_responses.append(response)
response_count += 1
d = response.to_dict()
# Only print when transcription text actually changes
current_text = _extract_text_from_response(d)
if current_text and current_text != last_printed_text:
buf = d.get("buffer_transcription")
buf = buf.strip() if isinstance(buf, str) else ""
committed = current_text
if buf and committed.endswith(buf):
committed = committed[:-len(buf)].strip()
# Show committed text + buffer separately
display = committed
if buf:
display = f"{committed} \033[90m{buf}\033[0m" if committed else f"\033[90m{buf}\033[0m"
print(f" > {display}", flush=True)
last_printed_text = current_text
result_task = asyncio.create_task(collect_results())
# --- Client side: feed audio as PCM bytes ---
t_start = time.time()
for offset in range(0, total_samples, chunk_samples):
chunk = audio[offset : offset + chunk_samples]
pcm_bytes = float32_to_s16le_bytes(chunk)
await processor.process_audio(pcm_bytes)
if realtime:
await asyncio.sleep(chunk_ms / 1000)
feed_elapsed = time.time() - t_start
logger.info(f"Audio fed in {feed_elapsed:.2f}s. Signaling EOF...")
# Signal end of audio (like client disconnect / empty message)
await processor.process_audio(None)
# Wait for pipeline to drain completely
try:
await asyncio.wait_for(result_task, timeout=120.0)
except asyncio.TimeoutError:
logger.warning("Timed out waiting for results. Proceeding with cleanup.")
result_task.cancel()
try:
await result_task
except asyncio.CancelledError:
pass
# --- Capture word-level timestamps before cleanup ---
word_timestamps = []
try:
state = await processor.get_current_state()
for token in state.tokens:
if hasattr(token, 'start') and hasattr(token, 'text') and token.text:
word_timestamps.append(WordTimestamp(
word=token.text.strip(),
start=round(token.start, 3),
end=round(token.end, 3),
))
except Exception as e:
logger.warning(f"Could not capture word timestamps: {e}")
# Cleanup
await processor.cleanup()
total_elapsed = time.time() - t_start
# --- Build result ---
transcription = ""
n_lines = 0
last_response_dict = None
if all_responses:
last = all_responses[-1].to_dict()
last_response_dict = last
n_lines = len(last.get("lines", []))
transcription = _extract_text_from_response(last)
# --- Compute WER and timestamp accuracy against ground truth ---
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy
wer_val = None
wer_details = None
ts_mae = None
ts_max_delta = None
ts_median_delta = None
gt_path = Path(audio_file).with_suffix(".transcript.json")
if not gt_path.exists():
gt_path = AUDIO_TESTS_DIR / gt_path
gt = None
if gt_path.exists():
with open(gt_path) as f:
gt = json.load(f)
# WER
gt_text = " ".join(w["word"] for w in gt)
wer_result = compute_wer(gt_text, transcription)
wer_val = round(wer_result["wer"], 4)
wer_details = wer_result
# Timestamp accuracy
if word_timestamps:
pred_dicts = [{"word": wt.word, "start": wt.start, "end": wt.end} for wt in word_timestamps]
ts_result = compute_timestamp_accuracy(pred_dicts, gt)
ts_mae = ts_result["mae_start"]
ts_max_delta = ts_result["max_delta_start"]
ts_median_delta = ts_result["median_delta_start"]
result = TestResult(
audio_file=audio_file,
audio_duration_s=round(audio_duration, 2),
backend=backend,
policy=policy,
language=lan,
chunk_ms=chunk_ms,
realtime_pacing=realtime,
processing_time_s=round(total_elapsed, 2),
rtf=round(total_elapsed / audio_duration, 2),
transcription=transcription,
n_lines=n_lines,
n_responses=response_count,
wer=wer_val,
wer_details=wer_details,
timestamp_mae=round(ts_mae, 3) if ts_mae is not None else None,
timestamp_max_delta=round(ts_max_delta, 3) if ts_max_delta is not None else None,
timestamp_median_delta=round(ts_median_delta, 3) if ts_median_delta is not None else None,
word_timestamps=word_timestamps,
last_response=last_response_dict,
)
# --- Print summary ---
print(f"\n{'=' * 60}")
print(f"RESULT: {audio_file}")
print(f"{'=' * 60}")
print(f"Transcription: {transcription}")
print(f"Lines: {n_lines} | Responses: {response_count}")
print(f"Audio: {audio_duration:.2f}s | Time: {total_elapsed:.2f}s | RTF: {result.rtf:.2f}x")
if wer_val is not None:
print(f"WER: {wer_val:.2%} (S={wer_details['substitutions']} I={wer_details['insertions']} D={wer_details['deletions']})")
# Print word timestamps if available
if word_timestamps:
print(f"\nWord timestamps ({len(word_timestamps)} words):")
for wt in word_timestamps:
print(f" [{wt.start:6.2f} - {wt.end:6.2f}] {wt.word}")
# Detailed comparison with ground truth
if gt:
print(f"\n vs Ground truth ({len(gt)} words):")
max_words = max(len(word_timestamps), len(gt))
for i in range(max_words):
pred = word_timestamps[i] if i < len(word_timestamps) else None
ref = gt[i] if i < len(gt) else None
p_str = f"[{pred.start:5.2f}-{pred.end:5.2f}] {pred.word:<15}" if pred else " " * 30
r_str = f"[{ref['start']:5.2f}-{ref['end']:5.2f}] {ref['word']:<15}" if ref else ""
delta = ""
if pred and ref:
d = pred.start - ref['start']
delta = f" Δstart={d:+.2f}"
print(f" {p_str} | {r_str}{delta}")
if ts_mae is not None:
print(f"\n Timestamp stats: MAE={ts_mae:.3f}s max|Δ|={ts_max_delta:.3f}s median|Δ|={ts_median_delta:.3f}s")
print(f"{'=' * 60}")
return result
def discover_audio_files(directory: str) -> List[Path]:
"""Find all supported audio files in directory."""
d = Path(directory)
files = sorted(
p for p in d.iterdir()
if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
)
return files
async def run_all_tests(
engine, audio_files: List[Path], chunk_ms: int, realtime: bool,
backend: str, policy: str, lan: str, max_duration: float = 60.0,
silence_insertions: Optional[List[List[float]]] = None,
) -> List[TestResult]:
"""Run tests on multiple audio files sequentially."""
results = []
for audio_path in audio_files:
# Detect language from filename if "french" in name
file_lan = lan
if "french" in audio_path.name.lower() and lan == "en":
file_lan = "fr"
logger.info(f"Auto-detected language 'fr' from filename")
audio = load_audio(str(audio_path))
# Insert silence segments (applied in reverse position order to keep offsets valid)
if silence_insertions:
for secs, at_sec in sorted(silence_insertions, key=lambda x: x[1], reverse=True):
logger.info(f"Inserting {secs:.1f}s silence at {at_sec:.1f}s")
audio = insert_silence(audio, secs, at_sec)
duration = len(audio) / SAMPLE_RATE
if duration > max_duration:
logger.info(f"Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s max)")
continue
print(f"\n{'#' * 60}")
print(f"# Testing: {audio_path.name} ({duration:.1f}s)")
print(f"{'#' * 60}")
result = await run_test(
engine, audio, chunk_ms, realtime,
audio_file=audio_path.name, backend=backend, policy=policy, lan=file_lan,
)
results.append(result)
return results
def print_benchmark_summary(results: List[TestResult]):
"""Print a tabular summary of all test results."""
print(f"\n{'=' * 110}")
print("BENCHMARK SUMMARY")
print(f"{'=' * 110}")
print(
f"{'File':<40} {'Duration':>8} {'Time':>8} {'RTF':>6} "
f"{'WER':>7} {'MAE(s)':>7} {'Lines':>5}"
)
print(f"{'-' * 110}")
for r in results:
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
print(
f"{r.audio_file:<40} {r.audio_duration_s:>7.1f}s {r.processing_time_s:>7.1f}s "
f"{r.rtf:>5.2f}x {wer_str:>7} {mae_str:>7} {r.n_lines:>5}"
)
print(f"{'-' * 110}")
total_audio = sum(r.audio_duration_s for r in results)
total_time = sum(r.processing_time_s for r in results)
avg_rtf = total_time / total_audio if total_audio > 0 else 0
wer_vals = [r.wer for r in results if r.wer is not None]
avg_wer_str = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
mae_vals = [r.timestamp_mae for r in results if r.timestamp_mae is not None]
avg_mae_str = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
print(
f"{'TOTAL/AVG':<40} {total_audio:>7.1f}s {total_time:>7.1f}s "
f"{avg_rtf:>5.2f}x {avg_wer_str:>7} {avg_mae_str:>7}"
)
print(f"{'=' * 110}")
# Print transcription excerpts
print(f"\nTRANSCRIPTIONS:")
print(f"{'-' * 110}")
for r in results:
excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription
print(f" {r.audio_file}:")
print(f" {excerpt}")
print(f"{'=' * 110}")
def detect_available_backends() -> List[dict]:
"""Probe which backends can be imported and return (backend, policy) combos.
Returns list of dicts with keys: backend, policy, description.
"""
combos = []
# faster-whisper
try:
import faster_whisper # noqa: F401
combos.append({"backend": "faster-whisper", "policy": "localagreement", "description": "faster-whisper + LocalAgreement"})
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "description": "faster-whisper + SimulStreaming"})
except ImportError:
pass
# mlx-whisper (macOS only)
try:
import mlx_whisper # noqa: F401
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "description": "mlx-whisper + LocalAgreement"})
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "description": "mlx-whisper + SimulStreaming"})
except ImportError:
pass
# openai-whisper
try:
import whisper # noqa: F401
combos.append({"backend": "whisper", "policy": "localagreement", "description": "openai-whisper + LocalAgreement"})
combos.append({"backend": "whisper", "policy": "simulstreaming", "description": "openai-whisper + SimulStreaming"})
except ImportError:
pass
# voxtral-mlx
try:
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "description": "voxtral-mlx (MLX)"})
except ImportError:
pass
# voxtral (HuggingFace)
try:
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
combos.append({"backend": "voxtral", "policy": "voxtral", "description": "voxtral (HuggingFace)"})
except ImportError:
pass
return combos
def print_cross_backend_comparison(all_results: List[TestResult]):
"""Print a comparison table across backends and policies."""
print(f"\n{'=' * 110}")
print("CROSS-BACKEND BENCHMARK COMPARISON")
print(f"{'=' * 110}")
print(
f"{'Backend':<18} {'Policy':<16} {'File':<30} "
f"{'WER':>7} {'RTF':>6} {'MAE(s)':>7} {'MaxΔ(s)':>8}"
)
print(f"{'-' * 110}")
for r in all_results:
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
rtf_str = f"{r.rtf:.2f}x"
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
max_str = f"{r.timestamp_max_delta:.3f}" if r.timestamp_max_delta is not None else " -"
# Truncate filename for readability
fname = r.audio_file[:28] + ".." if len(r.audio_file) > 30 else r.audio_file
print(
f"{r.backend:<18} {r.policy:<16} {fname:<30} "
f"{wer_str:>7} {rtf_str:>6} {mae_str:>7} {max_str:>8}"
)
print(f"{'-' * 110}")
# Per-backend averages
from collections import defaultdict
by_combo = defaultdict(list)
for r in all_results:
by_combo[(r.backend, r.policy)].append(r)
print(f"\n{'Backend':<18} {'Policy':<16} {'Avg WER':>8} {'Avg RTF':>8} {'Avg MAE':>8} {'Files':>6}")
print(f"{'-' * 80}")
for (backend, policy), group in sorted(by_combo.items()):
wer_vals = [r.wer for r in group if r.wer is not None]
rtf_vals = [r.rtf for r in group]
mae_vals = [r.timestamp_mae for r in group if r.timestamp_mae is not None]
avg_wer = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
avg_rtf = f"{sum(rtf_vals)/len(rtf_vals):.2f}x"
avg_mae = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
print(
f"{backend:<18} {policy:<16} {avg_wer:>8} {avg_rtf:>8} {avg_mae:>8} {len(group):>6}"
)
print(f"{'=' * 110}")
def _quiet_loggers(verbose: bool):
"""Set internal module log levels to reduce noise."""
if verbose:
logging.getLogger().setLevel(logging.DEBUG)
else:
for mod in (
"whisperlivekit.audio_processor", "whisperlivekit.simul_whisper",
"whisperlivekit.tokens_alignment", "whisperlivekit.simul_whisper.align_att_base",
"whisperlivekit.simul_whisper.simul_whisper",
):
logging.getLogger(mod).setLevel(logging.WARNING)
async def run_benchmark(
audio_files: List[Path], chunk_ms: int, realtime: bool,
model_size: str, lan: str, max_duration: float, vac: bool,
verbose: bool,
) -> List[TestResult]:
"""Run benchmark across all available backend+policy combinations."""
combos = detect_available_backends()
if not combos:
logger.error("No backends available. Install at least one ASR backend.")
return []
logger.info(f"Detected {len(combos)} backend+policy combinations:")
for c in combos:
logger.info(f" - {c['description']}")
all_results = []
for i, combo in enumerate(combos, 1):
backend = combo["backend"]
policy = combo["policy"]
desc = combo["description"]
print(f"\n{'*' * 70}")
print(f"* BENCHMARK {i}/{len(combos)}: {desc}")
print(f"{'*' * 70}")
try:
engine = create_engine(
backend, model_size, lan, vac=vac, policy=policy,
)
_quiet_loggers(verbose)
results = await run_all_tests(
engine, audio_files, chunk_ms, realtime,
backend=backend, policy=policy, lan=lan,
max_duration=max_duration,
)
all_results.extend(results)
except Exception as e:
logger.error(f"Failed to run {desc}: {e}")
import traceback
traceback.print_exc()
return all_results
def main():
parser = argparse.ArgumentParser(
description="Offline backend test harness (AudioProcessor-level)"
)
parser.add_argument(
"--backend", default="faster-whisper",
help="Backend: voxtral, voxtral-mlx, auto, faster-whisper, mlx-whisper, whisper.",
)
parser.add_argument(
"--policy", default="",
help="Override backend policy: localagreement, simulstreaming, voxtral.",
)
parser.add_argument(
"--audio", default=None,
help="Path to a single audio file (WAV, MP3, FLAC, etc.).",
)
parser.add_argument(
"--audio-dir", default=None,
help="Directory of audio files to test. Defaults to audio_tests/ if neither --audio nor --audio-dir given.",
)
parser.add_argument(
"--chunk-ms", type=int, default=100,
help="Chunk size in milliseconds (simulates real-time interval).",
)
parser.add_argument(
"--model", default="", dest="model_size",
help="Model size or HF repo ID.",
)
parser.add_argument("--lan", default="en", help="Language code.")
parser.add_argument(
"--no-realtime", action="store_true",
help="Skip real-time pacing between chunks (faster but less realistic).",
)
parser.add_argument(
"--no-vac", action="store_true",
help="Disable Voice Activity Classification (send all audio without silence filtering).",
)
parser.add_argument(
"--diarization", action="store_true",
help="Enable speaker diarization.",
)
parser.add_argument(
"--diarization-backend",
default="",
choices=["diart", "sortformer"],
help="Diarization backend when --diarization is enabled.",
)
parser.add_argument(
"--benchmark", action="store_true",
help="Run benchmark across all detected backend+policy combinations.",
)
parser.add_argument(
"--json", default=None, dest="json_output",
help="Write structured JSON results to this file.",
)
parser.add_argument(
"--max-duration", type=float, default=60.0,
help="Skip audio files longer than this many seconds (default: 60).",
)
parser.add_argument(
"--insert-silence", nargs=2, type=float, metavar=("SECS", "AT_SEC"),
action="append", default=[],
help="Insert SECS of silence at AT_SEC position. Can be repeated. "
"E.g.: --insert-silence 3.0 2.0 --insert-silence 5.0 7.0",
)
parser.add_argument(
"-v", "--verbose", action="store_true",
help="Show debug-level logs from all components.",
)
args = parser.parse_args()
realtime = not args.no_realtime
vac = not args.no_vac
# Resolve audio file(s)
if args.audio:
audio_files = [Path(args.audio)]
elif args.audio_dir:
audio_files = discover_audio_files(args.audio_dir)
elif AUDIO_TESTS_DIR.is_dir():
audio_files = discover_audio_files(str(AUDIO_TESTS_DIR))
else:
# Fall back to jfk.wav download
audio_files = [download_sample_audio()]
if not audio_files:
logger.error("No audio files found.")
sys.exit(1)
logger.info(f"Audio files: {[f.name for f in audio_files]}")
if args.benchmark:
# --- Multi-backend benchmark mode ---
all_results = asyncio.run(
run_benchmark(
audio_files, args.chunk_ms, realtime,
args.model_size, args.lan, args.max_duration, vac,
args.verbose,
)
)
if all_results:
print_cross_backend_comparison(all_results)
results = all_results
else:
# --- Single-backend mode ---
policy = args.policy
logger.info(f"Creating {args.backend} engine...")
engine = create_engine(
args.backend, args.model_size, args.lan,
diarization=args.diarization,
diarization_backend=args.diarization_backend,
vac=vac,
policy=policy,
)
logger.info("Engine ready.")
_quiet_loggers(args.verbose)
results = asyncio.run(
run_all_tests(
engine, audio_files, args.chunk_ms, realtime,
args.backend, policy, args.lan,
max_duration=args.max_duration,
silence_insertions=args.insert_silence or None,
)
)
if len(results) > 1:
print_benchmark_summary(results)
# JSON output
if args.json_output and results:
json_results = []
for r in results:
d = asdict(r)
d.pop("last_response", None) # too verbose for summary
json_results.append(d)
Path(args.json_output).write_text(
json.dumps(json_results, indent=2, ensure_ascii=False)
)
logger.info(f"Results written to {args.json_output}")
if __name__ == "__main__":
main()

0
tests/__init__.py Normal file
View File

58
tests/conftest.py Normal file
View File

@@ -0,0 +1,58 @@
"""Shared pytest fixtures for WhisperLiveKit tests."""
import json
from pathlib import Path
from types import SimpleNamespace
import pytest
from whisperlivekit.timed_objects import ASRToken, Silence, Transcript
AUDIO_TESTS_DIR = Path(__file__).parent.parent / "audio_tests"
@pytest.fixture
def sample_tokens():
"""A short sequence of ASRToken objects."""
return [
ASRToken(start=0.0, end=0.5, text="Hello"),
ASRToken(start=0.5, end=1.0, text=" world"),
ASRToken(start=1.0, end=1.5, text=" test."),
]
@pytest.fixture
def sample_silence():
"""A completed silence event."""
s = Silence(start=1.5, end=3.0, is_starting=False, has_ended=True)
s.compute_duration()
return s
@pytest.fixture
def mock_args():
"""Minimal args namespace for AudioProcessor tests."""
return SimpleNamespace(
diarization=False,
transcription=True,
target_language="",
vac=False,
vac_chunk_size=0.04,
min_chunk_size=0.1,
pcm_input=True,
punctuation_split=False,
backend="faster-whisper",
backend_policy="localagreement",
vad=True,
)
@pytest.fixture
def ground_truth_en():
"""Ground truth transcript for the 7s English audio (if available)."""
path = AUDIO_TESTS_DIR / "00_00_07_english_1_speaker.transcript.json"
if path.exists():
with open(path) as f:
return json.load(f)
return None

View File

@@ -0,0 +1,209 @@
"""Tests for AudioProcessor pipeline with mocked ASR backends.
These tests verify the async audio processing pipeline works correctly
without requiring any real ASR models to be loaded.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
import numpy as np
import pytest
from whisperlivekit.timed_objects import ASRToken, Transcript
# ---------------------------------------------------------------------------
# Mock ASR components
# ---------------------------------------------------------------------------
class MockASR:
"""Mock ASR model holder."""
sep = " "
SAMPLING_RATE = 16000
def __init__(self):
self.transcribe_kargs = {}
self.original_language = "en"
self.backend_choice = "mock"
def transcribe(self, audio):
return None
class MockOnlineProcessor:
"""Mock online processor that returns canned tokens."""
SAMPLING_RATE = 16000
def __init__(self, asr=None):
self.asr = asr or MockASR()
self.audio_buffer = np.array([], dtype=np.float32)
self.end = 0.0
self._call_count = 0
self._finished = False
def insert_audio_chunk(self, audio, audio_stream_end_time):
self.audio_buffer = np.append(self.audio_buffer, audio)
self.end = audio_stream_end_time
def process_iter(self, is_last=False):
self._call_count += 1
# Emit a token on every call when we have audio
if len(self.audio_buffer) > 0:
t = self._call_count * 0.5
return [ASRToken(start=t, end=t + 0.5, text=f"word{self._call_count}")], self.end
return [], self.end
def get_buffer(self):
return Transcript(start=None, end=None, text="")
def start_silence(self):
return [], self.end
def end_silence(self, silence_duration, offset):
pass
def new_speaker(self, change_speaker):
pass
def finish(self):
self._finished = True
return [], self.end
def warmup(self, audio, init_prompt=""):
pass
def _make_pcm_bytes(duration_s=0.1, sample_rate=16000):
"""Generate silent PCM s16le bytes."""
n_samples = int(duration_s * sample_rate)
audio = np.zeros(n_samples, dtype=np.float32)
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_engine():
"""Create a mock TranscriptionEngine-like object."""
engine = SimpleNamespace(
asr=MockASR(),
diarization_model=None,
translation_model=None,
args=SimpleNamespace(
diarization=False,
transcription=True,
target_language="",
vac=False,
vac_chunk_size=0.04,
min_chunk_size=0.1,
pcm_input=True,
punctuation_split=False,
backend="mock",
backend_policy="localagreement",
vad=True,
model_size="base",
lan="en",
),
)
return engine
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestPCMConversion:
"""Test PCM byte conversion without needing the full pipeline."""
def test_s16le_roundtrip(self):
"""Convert float32 → s16le → float32 and verify approximate roundtrip."""
original = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype=np.float32)
s16 = (original * 32768).clip(-32768, 32767).astype(np.int16)
pcm_bytes = s16.tobytes()
# Direct numpy conversion (same logic as AudioProcessor.convert_pcm_to_float)
recovered = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
np.testing.assert_allclose(recovered, original, atol=1 / 32768)
@pytest.mark.asyncio
class TestPipelineBasics:
async def test_feed_audio_and_get_responses(self, mock_engine):
"""Feed audio through the pipeline and verify we get responses."""
from whisperlivekit.audio_processor import AudioProcessor
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
processor = AudioProcessor(transcription_engine=mock_engine)
results_gen = await processor.create_tasks()
responses = []
async def collect():
async for resp in results_gen:
responses.append(resp)
task = asyncio.create_task(collect())
# Feed 2 seconds of audio in 100ms chunks
for _ in range(20):
await processor.process_audio(_make_pcm_bytes(0.1))
# Signal EOF
await processor.process_audio(None)
await asyncio.wait_for(task, timeout=10.0)
await processor.cleanup()
# We should have gotten at least one response
assert len(responses) > 0
async def test_eof_terminates_pipeline(self, mock_engine):
"""Sending None (EOF) should cleanly terminate the pipeline."""
from whisperlivekit.audio_processor import AudioProcessor
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
processor = AudioProcessor(transcription_engine=mock_engine)
results_gen = await processor.create_tasks()
responses = []
async def collect():
async for resp in results_gen:
responses.append(resp)
task = asyncio.create_task(collect())
# Send a small amount of audio then EOF
await processor.process_audio(_make_pcm_bytes(0.5))
await processor.process_audio(None)
await asyncio.wait_for(task, timeout=10.0)
await processor.cleanup()
# Pipeline should have terminated without error
assert task.done()
async def test_empty_audio_no_crash(self, mock_engine):
"""Sending EOF immediately (no audio) should not crash."""
from whisperlivekit.audio_processor import AudioProcessor
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
processor = AudioProcessor(transcription_engine=mock_engine)
results_gen = await processor.create_tasks()
responses = []
async def collect():
async for resp in results_gen:
responses.append(resp)
task = asyncio.create_task(collect())
await processor.process_audio(None)
await asyncio.wait_for(task, timeout=10.0)
await processor.cleanup()
assert task.done()

99
tests/test_config.py Normal file
View File

@@ -0,0 +1,99 @@
"""Tests for WhisperLiveKitConfig."""
import logging
from types import SimpleNamespace
import pytest
from whisperlivekit.config import WhisperLiveKitConfig
class TestDefaults:
def test_default_backend(self):
c = WhisperLiveKitConfig()
assert c.backend == "auto"
def test_default_policy(self):
c = WhisperLiveKitConfig()
assert c.backend_policy == "simulstreaming"
def test_default_language(self):
c = WhisperLiveKitConfig()
assert c.lan == "auto"
def test_default_vac(self):
c = WhisperLiveKitConfig()
assert c.vac is True
def test_default_model_size(self):
c = WhisperLiveKitConfig()
assert c.model_size == "base"
def test_default_transcription(self):
c = WhisperLiveKitConfig()
assert c.transcription is True
assert c.diarization is False
class TestPostInit:
def test_en_model_forces_english(self):
c = WhisperLiveKitConfig(model_size="tiny.en")
assert c.lan == "en"
def test_en_suffix_with_auto_language(self):
c = WhisperLiveKitConfig(model_size="base.en", lan="auto")
assert c.lan == "en"
def test_non_en_model_keeps_language(self):
c = WhisperLiveKitConfig(model_size="base", lan="fr")
assert c.lan == "fr"
def test_policy_alias_1(self):
c = WhisperLiveKitConfig(backend_policy="1")
assert c.backend_policy == "simulstreaming"
def test_policy_alias_2(self):
c = WhisperLiveKitConfig(backend_policy="2")
assert c.backend_policy == "localagreement"
def test_policy_no_alias(self):
c = WhisperLiveKitConfig(backend_policy="localagreement")
assert c.backend_policy == "localagreement"
class TestFromNamespace:
def test_known_keys(self):
ns = SimpleNamespace(backend="faster-whisper", lan="en", model_size="large-v3")
c = WhisperLiveKitConfig.from_namespace(ns)
assert c.backend == "faster-whisper"
assert c.lan == "en"
assert c.model_size == "large-v3"
def test_ignores_unknown_keys(self):
ns = SimpleNamespace(backend="auto", unknown_key="value", another="x")
c = WhisperLiveKitConfig.from_namespace(ns)
assert c.backend == "auto"
assert not hasattr(c, "unknown_key")
def test_preserves_defaults_for_missing(self):
ns = SimpleNamespace(backend="voxtral-mlx")
c = WhisperLiveKitConfig.from_namespace(ns)
assert c.lan == "auto"
assert c.vac is True
class TestFromKwargs:
def test_known_keys(self):
c = WhisperLiveKitConfig.from_kwargs(backend="mlx-whisper", lan="fr")
assert c.backend == "mlx-whisper"
assert c.lan == "fr"
def test_warns_on_unknown_keys(self, caplog):
with caplog.at_level(logging.WARNING, logger="whisperlivekit.config"):
c = WhisperLiveKitConfig.from_kwargs(backend="auto", bogus="value")
assert c.backend == "auto"
assert "bogus" in caplog.text
def test_post_init_runs(self):
c = WhisperLiveKitConfig.from_kwargs(model_size="small.en")
assert c.lan == "en"

View File

@@ -0,0 +1,172 @@
"""Tests for HypothesisBuffer — the core of LocalAgreement policy."""
import pytest
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.local_agreement.online_asr import HypothesisBuffer
def make_tokens(words, start=0.0, step=0.5):
"""Helper: create ASRToken list from word strings."""
tokens = []
t = start
for w in words:
tokens.append(ASRToken(start=t, end=t + step, text=w, probability=0.9))
t += step
return tokens
class TestInsert:
def test_basic_insert(self):
buf = HypothesisBuffer()
tokens = make_tokens(["hello", "world"])
buf.insert(tokens, offset=0.0)
assert len(buf.new) == 2
assert buf.new[0].text == "hello"
def test_insert_with_offset(self):
buf = HypothesisBuffer()
tokens = make_tokens(["hello"], start=0.0)
buf.insert(tokens, offset=5.0)
assert buf.new[0].start == pytest.approx(5.0)
def test_insert_filters_old_tokens(self):
buf = HypothesisBuffer()
buf.last_committed_time = 10.0
tokens = make_tokens(["old", "new"], start=5.0, step=3.0)
buf.insert(tokens, offset=0.0)
# "old" at 5.0 is before last_committed_time - 0.1 = 9.9 → filtered
# "new" at 8.0 is also before 9.9 → filtered
assert len(buf.new) == 0
def test_insert_deduplicates_committed(self):
buf = HypothesisBuffer()
# Commit "hello"
tokens1 = make_tokens(["hello", "world"])
buf.insert(tokens1, offset=0.0)
buf.flush() # commits "hello" (buffer was empty, so nothing matches)
# Actually with empty buffer, flush won't commit anything
# Let's do it properly: two rounds
buf2 = HypothesisBuffer()
first = make_tokens(["hello", "world"])
buf2.insert(first, offset=0.0)
buf2.flush() # buffer was empty → no commits, buffer = ["hello", "world"]
second = make_tokens(["hello", "world", "test"])
buf2.insert(second, offset=0.0)
committed = buf2.flush()
# LCP of ["hello", "world"] and ["hello", "world", "test"] = ["hello", "world"]
assert len(committed) == 2
assert committed[0].text == "hello"
assert committed[1].text == "world"
class TestFlush:
def test_flush_empty(self):
buf = HypothesisBuffer()
committed = buf.flush()
assert committed == []
def test_flush_lcp_matching(self):
buf = HypothesisBuffer()
# Round 1: establish buffer
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
buf.flush() # buffer = ["hello", "world"], committed = []
# Round 2: same prefix, new suffix
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
committed = buf.flush()
assert [t.text for t in committed] == ["hello", "world"]
def test_flush_no_match(self):
buf = HypothesisBuffer()
# Round 1
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
buf.flush()
# Round 2: completely different
buf.insert(make_tokens(["foo", "bar"]), offset=0.0)
committed = buf.flush()
assert committed == []
def test_flush_partial_match(self):
buf = HypothesisBuffer()
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
buf.flush()
buf.insert(make_tokens(["hello", "earth", "again"]), offset=0.0)
committed = buf.flush()
assert len(committed) == 1
assert committed[0].text == "hello"
def test_flush_updates_last_committed(self):
buf = HypothesisBuffer()
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
buf.flush()
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
buf.flush()
assert buf.last_committed_word == "world"
assert buf.last_committed_time > 0
def test_flush_with_confidence_validation(self):
buf = HypothesisBuffer(confidence_validation=True)
high_conf = [
ASRToken(start=0.0, end=0.5, text="sure", probability=0.99),
ASRToken(start=0.5, end=1.0, text="maybe", probability=0.5),
]
buf.insert(high_conf, offset=0.0)
committed = buf.flush()
# "sure" has p>0.95 → committed immediately
assert len(committed) == 1
assert committed[0].text == "sure"
class TestPopCommitted:
def test_pop_removes_old(self):
buf = HypothesisBuffer()
buf.committed_in_buffer = make_tokens(["a", "b", "c"], start=0.0, step=1.0)
# "a": end=1.0, "b": end=2.0, "c": end=3.0
# pop_committed removes tokens with end <= time
buf.pop_committed(2.0)
# "a" (end=1.0) and "b" (end=2.0) removed, "c" (end=3.0) remains
assert len(buf.committed_in_buffer) == 1
assert buf.committed_in_buffer[0].text == "c"
def test_pop_nothing(self):
buf = HypothesisBuffer()
buf.committed_in_buffer = make_tokens(["a", "b"], start=5.0)
buf.pop_committed(0.0)
assert len(buf.committed_in_buffer) == 2
def test_pop_all(self):
buf = HypothesisBuffer()
buf.committed_in_buffer = make_tokens(["a", "b"], start=0.0, step=0.5)
buf.pop_committed(100.0)
assert len(buf.committed_in_buffer) == 0
class TestStreamingSimulation:
"""Multi-round insert/flush simulating real streaming behavior."""
def test_three_rounds(self):
buf = HypothesisBuffer()
all_committed = []
# Round 1: "this is"
buf.insert(make_tokens(["this", "is"]), offset=0.0)
all_committed.extend(buf.flush())
# Round 2: "this is a test"
buf.insert(make_tokens(["this", "is", "a", "test"]), offset=0.0)
all_committed.extend(buf.flush())
# Round 3: "this is a test today"
buf.insert(make_tokens(["this", "is", "a", "test", "today"]), offset=0.0)
all_committed.extend(buf.flush())
words = [t.text for t in all_committed]
assert "this" in words
assert "is" in words
assert "a" in words
assert "test" in words

183
tests/test_metrics.py Normal file
View File

@@ -0,0 +1,183 @@
"""Tests for whisperlivekit.metrics — WER, timestamp accuracy, normalization."""
import pytest
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy, normalize_text
class TestNormalizeText:
def test_lowercase(self):
assert normalize_text("Hello World") == "hello world"
def test_strip_punctuation(self):
assert normalize_text("Hello, world!") == "hello world"
def test_collapse_whitespace(self):
assert normalize_text(" hello world ") == "hello world"
def test_keep_hyphens(self):
assert normalize_text("real-time") == "real-time"
def test_keep_apostrophes(self):
assert normalize_text("don't") == "don't"
def test_unicode_normalized(self):
# e + combining accent should be same as precomposed
assert normalize_text("caf\u0065\u0301") == normalize_text("caf\u00e9")
def test_empty(self):
assert normalize_text("") == ""
def test_only_punctuation(self):
assert normalize_text("...!?") == ""
class TestComputeWER:
def test_perfect_match(self):
result = compute_wer("hello world", "hello world")
assert result["wer"] == 0.0
assert result["substitutions"] == 0
assert result["insertions"] == 0
assert result["deletions"] == 0
def test_case_insensitive(self):
result = compute_wer("Hello World", "hello world")
assert result["wer"] == 0.0
def test_punctuation_ignored(self):
result = compute_wer("Hello, world!", "hello world")
assert result["wer"] == 0.0
def test_one_substitution(self):
result = compute_wer("hello world", "hello earth")
assert result["wer"] == pytest.approx(0.5)
assert result["substitutions"] == 1
def test_one_insertion(self):
result = compute_wer("hello world", "hello big world")
assert result["wer"] == pytest.approx(0.5)
assert result["insertions"] == 1
def test_one_deletion(self):
result = compute_wer("hello big world", "hello world")
assert result["wer"] == pytest.approx(1 / 3)
assert result["deletions"] == 1
def test_completely_different(self):
result = compute_wer("the cat sat", "a dog ran")
assert result["wer"] == pytest.approx(1.0)
def test_empty_reference(self):
result = compute_wer("", "hello")
assert result["wer"] == 1.0 # 1 insertion / 0 ref → treated as float(m)
assert result["ref_words"] == 0
def test_empty_hypothesis(self):
result = compute_wer("hello world", "")
assert result["wer"] == pytest.approx(1.0)
assert result["deletions"] == 2
def test_both_empty(self):
result = compute_wer("", "")
assert result["wer"] == 0.0
def test_ref_and_hyp_word_counts(self):
result = compute_wer("one two three", "one two three four")
assert result["ref_words"] == 3
assert result["hyp_words"] == 4
class TestComputeTimestampAccuracy:
def test_perfect_match(self):
words = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0},
]
result = compute_timestamp_accuracy(words, words)
assert result["mae_start"] == 0.0
assert result["max_delta_start"] == 0.0
assert result["n_matched"] == 2
def test_constant_offset(self):
ref = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0},
]
pred = [
{"word": "hello", "start": 0.1, "end": 0.6},
{"word": "world", "start": 0.6, "end": 1.1},
]
result = compute_timestamp_accuracy(pred, ref)
assert result["mae_start"] == pytest.approx(0.1)
assert result["max_delta_start"] == pytest.approx(0.1)
assert result["n_matched"] == 2
def test_mismatched_word_counts(self):
ref = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "beautiful", "start": 0.5, "end": 1.0},
{"word": "world", "start": 1.0, "end": 1.5},
]
pred = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 1.1, "end": 1.6},
]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 2
assert result["n_ref"] == 3
assert result["n_pred"] == 2
def test_empty_predicted(self):
ref = [{"word": "hello", "start": 0.0, "end": 0.5}]
result = compute_timestamp_accuracy([], ref)
assert result["mae_start"] is None
assert result["n_matched"] == 0
def test_empty_reference(self):
pred = [{"word": "hello", "start": 0.0, "end": 0.5}]
result = compute_timestamp_accuracy(pred, [])
assert result["mae_start"] is None
assert result["n_matched"] == 0
def test_case_insensitive_matching(self):
ref = [{"word": "Hello", "start": 0.0, "end": 0.5}]
pred = [{"word": "hello", "start": 0.1, "end": 0.6}]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 1
assert result["mae_start"] == pytest.approx(0.1)
def test_median_even_count(self):
"""Median with even number of matched words should average the two middle values."""
ref = [
{"word": "a", "start": 0.0, "end": 0.2},
{"word": "b", "start": 0.5, "end": 0.7},
{"word": "c", "start": 1.0, "end": 1.2},
{"word": "d", "start": 1.5, "end": 1.7},
]
pred = [
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
{"word": "b", "start": 0.7, "end": 0.9}, # delta 0.2
{"word": "c", "start": 1.3, "end": 1.5}, # delta 0.3
{"word": "d", "start": 1.9, "end": 2.1}, # delta 0.4
]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 4
# sorted abs deltas: [0.1, 0.2, 0.3, 0.4] -> median = (0.2 + 0.3) / 2 = 0.25
assert result["median_delta_start"] == pytest.approx(0.25)
def test_median_odd_count(self):
"""Median with odd number of matched words takes the middle value."""
ref = [
{"word": "a", "start": 0.0, "end": 0.2},
{"word": "b", "start": 0.5, "end": 0.7},
{"word": "c", "start": 1.0, "end": 1.2},
]
pred = [
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
{"word": "b", "start": 0.8, "end": 1.0}, # delta 0.3
{"word": "c", "start": 1.2, "end": 1.4}, # delta 0.2
]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 3
# sorted abs deltas: [0.1, 0.2, 0.3] -> median = 0.2
assert result["median_delta_start"] == pytest.approx(0.2)

View File

@@ -0,0 +1,99 @@
"""Tests for silence handling — state machine and double-counting regression."""
import pytest
from whisperlivekit.timed_objects import Silence
class TestSilenceStateMachine:
"""Test Silence object state transitions."""
def test_initial_state(self):
s = Silence(start=1.0, is_starting=True)
assert s.is_starting is True
assert s.has_ended is False
assert s.duration is None
assert s.end is None
def test_end_silence(self):
s = Silence(start=1.0, is_starting=True)
s.end = 3.0
s.is_starting = False
s.has_ended = True
s.compute_duration()
assert s.duration == pytest.approx(2.0)
def test_very_short_silence(self):
s = Silence(start=1.0, end=1.01, is_starting=False, has_ended=True)
s.compute_duration()
assert s.duration == pytest.approx(0.01)
def test_zero_duration_silence(self):
s = Silence(start=5.0, end=5.0)
s.compute_duration()
assert s.duration == pytest.approx(0.0)
class TestSilenceDoubleCounting:
"""Regression tests for the silence double-counting bug.
The bug: _begin_silence and _end_silence both pushed self.current_silence
to the queue. Since they were the same Python object, _end_silence's mutation
affected the already-queued start event. The consumer processed both as
ended silences, doubling the duration.
Fix: _begin_silence now pushes a separate Silence object for the start event.
"""
def test_start_and_end_are_separate_objects(self):
"""Simulate the fix: start event and end event must be different objects."""
# Simulate _begin_silence: creates start event as separate object
current_silence = Silence(start=1.0, is_starting=True)
start_event = Silence(start=1.0, is_starting=True) # separate copy
# Simulate _end_silence: mutates current_silence
current_silence.end = 3.0
current_silence.is_starting = False
current_silence.has_ended = True
current_silence.compute_duration()
# start_event should NOT be affected by mutations to current_silence
assert start_event.is_starting is True
assert start_event.has_ended is False
assert start_event.end is None
# current_silence (end event) has the final state
assert current_silence.has_ended is True
assert current_silence.duration == pytest.approx(2.0)
def test_single_object_would_cause_double_counting(self):
"""Demonstrate the bug: if same object is used for both events."""
shared = Silence(start=1.0, is_starting=True)
queue = [shared] # start event queued
# Mutate (simulates _end_silence)
shared.end = 3.0
shared.is_starting = False
shared.has_ended = True
shared.compute_duration()
queue.append(shared) # end event queued
# Both queue items point to the SAME mutated object
assert queue[0] is queue[1] # same reference
assert queue[0].has_ended is True # start event also shows ended!
# This would cause double-counting: both items have has_ended=True
# and duration=2.0, so the consumer adds 2.0 twice = 4.0
class TestConsecutiveSilences:
def test_multiple_silences(self):
"""Multiple silence periods should have independent durations."""
s1 = Silence(start=1.0, end=2.0)
s1.compute_duration()
s2 = Silence(start=5.0, end=8.0)
s2.compute_duration()
assert s1.duration == pytest.approx(1.0)
assert s2.duration == pytest.approx(3.0)
# Total silence should be sum, not accumulated on single object
assert s1.duration + s2.duration == pytest.approx(4.0)

185
tests/test_timed_objects.py Normal file
View File

@@ -0,0 +1,185 @@
"""Tests for whisperlivekit.timed_objects data classes."""
import pytest
from whisperlivekit.timed_objects import (
ASRToken,
FrontData,
Segment,
Silence,
TimedText,
Transcript,
format_time,
)
class TestFormatTime:
def test_zero(self):
assert format_time(0) == "0:00:00"
def test_one_minute(self):
assert format_time(60) == "0:01:00"
def test_one_hour(self):
assert format_time(3600) == "1:00:00"
def test_fractional_truncated(self):
assert format_time(61.9) == "0:01:01"
class TestASRToken:
def test_with_offset(self):
t = ASRToken(start=1.0, end=2.0, text="hello")
shifted = t.with_offset(0.5)
assert shifted.start == pytest.approx(1.5)
assert shifted.end == pytest.approx(2.5)
assert shifted.text == "hello"
def test_with_offset_preserves_fields(self):
t = ASRToken(start=0.0, end=1.0, text="hi", speaker=2, probability=0.95)
shifted = t.with_offset(1.0)
assert shifted.speaker == 2
assert shifted.probability == 0.95
def test_is_silence_false(self):
t = ASRToken(start=0.0, end=1.0, text="hello")
assert t.is_silence() is False
def test_bool_truthy(self):
t = ASRToken(start=0.0, end=1.0, text="hello")
assert bool(t) is True
def test_bool_falsy(self):
t = ASRToken(start=0.0, end=1.0, text="")
assert bool(t) is False
class TestTimedText:
def test_has_punctuation_period(self):
t = TimedText(text="hello.")
assert t.has_punctuation() is True
def test_has_punctuation_exclamation(self):
t = TimedText(text="wow!")
assert t.has_punctuation() is True
def test_has_punctuation_question(self):
t = TimedText(text="really?")
assert t.has_punctuation() is True
def test_has_punctuation_cjk(self):
t = TimedText(text="hello。")
assert t.has_punctuation() is True
def test_no_punctuation(self):
t = TimedText(text="hello world")
assert t.has_punctuation() is False
def test_duration(self):
t = TimedText(start=1.0, end=3.5)
assert t.duration() == pytest.approx(2.5)
def test_contains_timespan(self):
outer = TimedText(start=0.0, end=5.0)
inner = TimedText(start=1.0, end=3.0)
assert outer.contains_timespan(inner) is True
assert inner.contains_timespan(outer) is False
class TestSilence:
def test_compute_duration(self):
s = Silence(start=1.0, end=3.5)
d = s.compute_duration()
assert d == pytest.approx(2.5)
assert s.duration == pytest.approx(2.5)
def test_compute_duration_none_start(self):
s = Silence(start=None, end=3.5)
d = s.compute_duration()
assert d is None
def test_compute_duration_none_end(self):
s = Silence(start=1.0, end=None)
d = s.compute_duration()
assert d is None
def test_is_silence_true(self):
s = Silence()
assert s.is_silence() is True
class TestTranscript:
def test_from_tokens(self, sample_tokens):
t = Transcript.from_tokens(sample_tokens, sep="")
assert t.text == "Hello world test."
assert t.start == pytest.approx(0.0)
assert t.end == pytest.approx(1.5)
def test_from_tokens_with_sep(self, sample_tokens):
t = Transcript.from_tokens(sample_tokens, sep="|")
assert t.text == "Hello| world| test."
def test_from_empty_tokens(self):
t = Transcript.from_tokens([])
assert t.text == ""
assert t.start is None
assert t.end is None
def test_from_tokens_with_offset(self, sample_tokens):
t = Transcript.from_tokens(sample_tokens, offset=10.0)
assert t.start == pytest.approx(10.0)
assert t.end == pytest.approx(11.5)
class TestSegment:
def test_from_tokens(self, sample_tokens):
seg = Segment.from_tokens(sample_tokens)
assert seg is not None
assert seg.text == "Hello world test."
assert seg.start == pytest.approx(0.0)
assert seg.end == pytest.approx(1.5)
assert seg.speaker == -1
def test_from_silence_tokens(self):
silences = [
Silence(start=1.0, end=2.0),
Silence(start=2.0, end=3.0),
]
seg = Segment.from_tokens(silences, is_silence=True)
assert seg is not None
assert seg.speaker == -2
assert seg.is_silence() is True
assert seg.text is None
def test_from_empty_tokens(self):
seg = Segment.from_tokens([])
assert seg is None
def test_to_dict(self, sample_tokens):
seg = Segment.from_tokens(sample_tokens)
d = seg.to_dict()
assert "text" in d
assert "speaker" in d
assert "start" in d
assert "end" in d
class TestFrontData:
def test_to_dict_empty(self):
fd = FrontData()
d = fd.to_dict()
assert d["lines"] == []
assert d["buffer_transcription"] == ""
assert "error" not in d
def test_to_dict_with_error(self):
fd = FrontData(error="something broke")
d = fd.to_dict()
assert d["error"] == "something broke"
def test_to_dict_with_lines(self, sample_tokens):
seg = Segment.from_tokens(sample_tokens)
fd = FrontData(lines=[seg])
d = fd.to_dict()
assert len(d["lines"]) == 1
assert d["lines"][0]["text"] == "Hello world test."

6575
uv.lock generated Normal file

File diff suppressed because one or more lines are too long

View File

@@ -9,6 +9,7 @@ import numpy as np
from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory,
online_translation_factory)
from whisperlivekit.metrics_collector import SessionMetrics
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
@@ -118,6 +119,7 @@ class AudioProcessor:
self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.metrics: SessionMetrics = SessionMetrics()
self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None
@@ -139,25 +141,43 @@ class AudioProcessor:
if self.translation_queue:
await self.translation_queue.put(self.current_silence)
async def _begin_silence(self) -> None:
async def _begin_silence(self, at_sample: Optional[int] = None) -> None:
if self.current_silence:
return
now = time() - self.beg_loop
# Use audio stream time (sample-precise) for accurate silence duration
if at_sample is not None:
audio_t = at_sample / self.sample_rate
else:
audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
self.current_silence = Silence(
is_starting=True, start=now
is_starting=True, start=audio_t
)
await self._push_silence_event()
# Push a separate start-only event so _end_silence won't mutate it
start_event = Silence(is_starting=True, start=audio_t)
if self.transcription_queue:
await self.transcription_queue.put(start_event)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(start_event)
if self.translation_queue:
await self.translation_queue.put(start_event)
async def _end_silence(self) -> None:
async def _end_silence(self, at_sample: Optional[int] = None) -> None:
if not self.current_silence:
return
now = time() - self.beg_loop
self.current_silence.end = now
self.current_silence.is_starting=False
self.current_silence.has_ended=True
if at_sample is not None:
audio_t = at_sample / self.sample_rate
else:
audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
self.current_silence.end = audio_t
self.current_silence.is_starting = False
self.current_silence.has_ended = True
self.current_silence.compute_duration()
self.metrics.n_silence_events += 1
if self.current_silence.duration is not None:
self.metrics.total_silence_duration_s += self.current_silence.duration
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
self.state.new_tokens.append(self.current_silence)
# Push the completed silence as the end event (separate from the start event)
await self._push_silence_event()
self.current_silence = None
@@ -253,6 +273,34 @@ class AudioProcessor:
if self.translation:
await self.translation_queue.put(SENTINEL)
async def _finish_transcription(self) -> None:
"""Call finish() on the online processor to flush remaining tokens."""
if not self.transcription:
return
try:
if hasattr(self.transcription, 'finish'):
final_tokens, end_time = await asyncio.to_thread(self.transcription.finish)
else:
# SimulStreamingOnlineProcessor uses start_silence() → process_iter(is_last=True)
final_tokens, end_time = await asyncio.to_thread(self.transcription.start_silence)
final_tokens = final_tokens or []
if final_tokens:
logger.info(f"Finish flushed {len(final_tokens)} tokens")
_buffer_transcript = self.transcription.get_buffer()
async with self.lock:
self.state.tokens.extend(final_tokens)
self.state.buffer_transcription = _buffer_transcript
self.state.end_buffer = max(self.state.end_buffer, end_time)
self.state.new_tokens.extend(final_tokens)
self.state.new_tokens_buffer = _buffer_transcript
if self.translation_queue:
for token in final_tokens:
await self.translation_queue.put(token)
except Exception as e:
logger.warning(f"Error finishing transcription: {e}")
logger.debug(f"Traceback: {traceback.format_exc()}")
async def transcription_processor(self) -> None:
"""Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0
@@ -263,6 +311,7 @@ class AudioProcessor:
item = await get_all_from_queue(self.transcription_queue)
if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.")
await self._finish_transcription()
break
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
@@ -297,8 +346,13 @@ class AudioProcessor:
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
_t0 = time()
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
_dur = time() - _t0
self.metrics.transcription_durations.append(_dur)
self.metrics.n_transcription_calls += 1
new_tokens = new_tokens or []
self.metrics.n_tokens_produced += len(new_tokens)
_buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text
@@ -433,6 +487,7 @@ class AudioProcessor:
should_push = (response != self.last_response_content)
if should_push:
self.metrics.n_responses_sent += 1
yield response
self.last_response_content = response
@@ -535,6 +590,10 @@ class AudioProcessor:
logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.diarization:
self.diarization.close()
# Finalize session metrics
self.metrics.total_audio_duration_s = self.total_pcm_samples / self.sample_rate
self.metrics.log_summary()
logger.info("AudioProcessor cleanup complete.")
def _processing_tasks_done(self) -> bool:
@@ -553,6 +612,7 @@ class AudioProcessor:
if not self.beg_loop:
self.beg_loop = time()
self.metrics.session_start = self.beg_loop
self.current_silence = Silence(start=0.0, is_starting=True)
self.tokens_alignment.beg_loop = self.beg_loop
@@ -560,6 +620,10 @@ class AudioProcessor:
logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True
# Flush any remaining PCM data before signaling end-of-stream
if self.is_pcm_input and self.pcm_buffer:
await self._flush_remaining_pcm()
if self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
@@ -572,6 +636,8 @@ class AudioProcessor:
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
return
self.metrics.n_chunks_received += 1
if self.is_pcm_input:
self.pcm_buffer.extend(message)
await self.handle_pcm_data()
@@ -588,6 +654,11 @@ class AudioProcessor:
logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self) -> None:
# Without VAC, there's no speech detector to end the initial silence.
# Clear it on the first audio chunk so audio actually gets enqueued.
if not self.args.vac and self.current_silence:
await self._end_silence()
# Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec:
return
@@ -616,7 +687,7 @@ class AudioProcessor:
if res is not None:
if "start" in res and self.current_silence:
await self._end_silence()
await self._end_silence(at_sample=res.get("start"))
if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence(
@@ -624,7 +695,7 @@ class AudioProcessor:
)
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
await self._enqueue_active_audio(pre_silence_chunk)
await self._begin_silence()
await self._begin_silence(at_sample=res.get("end"))
if not self.current_silence:
await self._enqueue_active_audio(pcm_array)
@@ -633,3 +704,21 @@ class AudioProcessor:
if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1)
async def _flush_remaining_pcm(self) -> None:
"""Flush whatever PCM data remains in the buffer, regardless of size threshold."""
if not self.pcm_buffer:
return
aligned_size = (len(self.pcm_buffer) // self.bytes_per_sample) * self.bytes_per_sample
if aligned_size == 0:
return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_size])
self.pcm_buffer = self.pcm_buffer[aligned_size:]
# End any active silence so the audio gets enqueued
if self.current_silence:
await self._end_silence(at_sample=self.total_pcm_samples)
await self._enqueue_active_audio(pcm_array)
self.total_pcm_samples += len(pcm_array)
logger.info(f"Flushed remaining PCM buffer: {len(pcm_array)} samples ({len(pcm_array)/self.sample_rate:.2f}s)")

View File

@@ -92,7 +92,12 @@ class TranscriptionEngine:
}
if config.transcription:
if config.backend == "voxtral":
if config.backend == "voxtral-mlx":
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
self.tokenizer = None
self.asr = VoxtralMLXASR(**transcription_common_params)
logger.info("Using Voxtral MLX native backend")
elif config.backend == "voxtral":
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingASR
self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
@@ -169,6 +174,9 @@ class TranscriptionEngine:
def online_factory(args, asr):
if getattr(args, 'backend', None) == "voxtral-mlx":
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
return VoxtralMLXOnlineProcessor(asr)
if getattr(args, 'backend', None) == "voxtral":
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
return VoxtralHFStreamingOnlineProcessor(asr)

156
whisperlivekit/metrics.py Normal file
View File

@@ -0,0 +1,156 @@
"""Lightweight ASR evaluation metrics — no external dependencies.
Provides WER (Word Error Rate) computation via word-level Levenshtein distance,
text normalization, and word-level timestamp accuracy metrics with greedy alignment.
"""
import re
import unicodedata
from typing import Dict, List, Optional
def normalize_text(text: str) -> str:
"""Normalize text for WER comparison: lowercase, strip punctuation, collapse whitespace."""
text = text.lower()
# Normalize unicode (e.g., accented chars to composed form)
text = unicodedata.normalize("NFC", text)
# Remove punctuation (keep letters, numbers, spaces, hyphens within words)
text = re.sub(r"[^\w\s\-']", " ", text)
# Collapse whitespace
text = re.sub(r"\s+", " ", text).strip()
return text
def compute_wer(reference: str, hypothesis: str) -> Dict:
"""Compute Word Error Rate using word-level Levenshtein edit distance.
Args:
reference: Ground truth transcription.
hypothesis: Predicted transcription.
Returns:
Dict with keys: wer, substitutions, insertions, deletions, ref_words, hyp_words.
WER can exceed 1.0 if there are more errors than reference words.
"""
ref_words = normalize_text(reference).split()
hyp_words = normalize_text(hypothesis).split()
n = len(ref_words)
m = len(hyp_words)
if n == 0:
return {
"wer": 0.0 if m == 0 else float(m),
"substitutions": 0,
"insertions": m,
"deletions": 0,
"ref_words": 0,
"hyp_words": m,
}
# DP table: dp[i][j] = (edit_distance, substitutions, insertions, deletions)
dp = [[(0, 0, 0, 0) for _ in range(m + 1)] for _ in range(n + 1)]
for i in range(1, n + 1):
dp[i][0] = (i, 0, 0, i)
for j in range(1, m + 1):
dp[0][j] = (j, 0, j, 0)
for i in range(1, n + 1):
for j in range(1, m + 1):
if ref_words[i - 1] == hyp_words[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
else:
sub = dp[i - 1][j - 1]
ins = dp[i][j - 1]
dele = dp[i - 1][j]
sub_cost = (sub[0] + 1, sub[1] + 1, sub[2], sub[3])
ins_cost = (ins[0] + 1, ins[1], ins[2] + 1, ins[3])
del_cost = (dele[0] + 1, dele[1], dele[2], dele[3] + 1)
dp[i][j] = min(sub_cost, del_cost, ins_cost, key=lambda x: x[0])
dist, subs, ins, dels = dp[n][m]
return {
"wer": dist / n,
"substitutions": subs,
"insertions": ins,
"deletions": dels,
"ref_words": n,
"hyp_words": m,
}
def compute_timestamp_accuracy(
predicted: List[Dict],
reference: List[Dict],
) -> Dict:
"""Compute timestamp accuracy by aligning predicted words to reference words.
Uses greedy left-to-right alignment on normalized text. For each matched pair,
computes the start-time delta (predicted - reference).
Args:
predicted: List of dicts with keys: word, start, end.
reference: List of dicts with keys: word, start, end.
Returns:
Dict with keys: mae_start, max_delta_start, median_delta_start,
n_matched, n_ref, n_pred. Returns None values if no matches found.
"""
if not predicted or not reference:
return {
"mae_start": None,
"max_delta_start": None,
"median_delta_start": None,
"n_matched": 0,
"n_ref": len(reference),
"n_pred": len(predicted),
}
# Normalize words for matching
pred_norm = [normalize_text(p["word"]) for p in predicted]
ref_norm = [normalize_text(r["word"]) for r in reference]
# Greedy left-to-right alignment
deltas_start = []
ref_idx = 0
for p_idx, p_word in enumerate(pred_norm):
if not p_word:
continue
# Scan forward in reference to find a match (allow small skips)
search_limit = min(ref_idx + 3, len(ref_norm))
for r_idx in range(ref_idx, search_limit):
if ref_norm[r_idx] == p_word:
delta = predicted[p_idx]["start"] - reference[r_idx]["start"]
deltas_start.append(delta)
ref_idx = r_idx + 1
break
if not deltas_start:
return {
"mae_start": None,
"max_delta_start": None,
"median_delta_start": None,
"n_matched": 0,
"n_ref": len(reference),
"n_pred": len(predicted),
}
abs_deltas = [abs(d) for d in deltas_start]
sorted_abs = sorted(abs_deltas)
n = len(sorted_abs)
if n % 2 == 1:
median = sorted_abs[n // 2]
else:
median = (sorted_abs[n // 2 - 1] + sorted_abs[n // 2]) / 2
return {
"mae_start": sum(abs_deltas) / len(abs_deltas),
"max_delta_start": max(abs_deltas),
"median_delta_start": median,
"n_matched": len(deltas_start),
"n_ref": len(reference),
"n_pred": len(predicted),
}

View File

@@ -0,0 +1,84 @@
"""Lightweight runtime metrics for AudioProcessor sessions.
Zero external dependencies. Negligible overhead when not queried —
just integer increments and list appends during normal operation.
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List
logger = logging.getLogger(__name__)
@dataclass
class SessionMetrics:
"""Per-session metrics collected by AudioProcessor."""
session_start: float = 0.0
total_audio_duration_s: float = 0.0
total_processing_time_s: float = 0.0
# Chunk / call counters
n_chunks_received: int = 0
n_transcription_calls: int = 0
n_tokens_produced: int = 0
n_responses_sent: int = 0
# Per-call ASR latency (seconds)
transcription_durations: List[float] = field(default_factory=list)
# Silence
n_silence_events: int = 0
total_silence_duration_s: float = 0.0
# --- Computed properties ---
@property
def rtf(self) -> float:
"""Real-time factor: processing_time / audio_duration."""
if self.total_audio_duration_s <= 0:
return 0.0
return self.total_processing_time_s / self.total_audio_duration_s
@property
def avg_latency_ms(self) -> float:
"""Average per-call ASR latency in milliseconds."""
if not self.transcription_durations:
return 0.0
return (sum(self.transcription_durations) / len(self.transcription_durations)) * 1000
@property
def p95_latency_ms(self) -> float:
"""95th percentile per-call ASR latency in milliseconds."""
if not self.transcription_durations:
return 0.0
sorted_d = sorted(self.transcription_durations)
idx = int(len(sorted_d) * 0.95)
idx = min(idx, len(sorted_d) - 1)
return sorted_d[idx] * 1000
def to_dict(self) -> Dict:
"""Serialize to a plain dict (JSON-safe)."""
return {
"session_start": self.session_start,
"total_audio_duration_s": round(self.total_audio_duration_s, 3),
"total_processing_time_s": round(self.total_processing_time_s, 3),
"rtf": round(self.rtf, 3),
"n_chunks_received": self.n_chunks_received,
"n_transcription_calls": self.n_transcription_calls,
"n_tokens_produced": self.n_tokens_produced,
"n_responses_sent": self.n_responses_sent,
"avg_latency_ms": round(self.avg_latency_ms, 2),
"p95_latency_ms": round(self.p95_latency_ms, 2),
"n_silence_events": self.n_silence_events,
"total_silence_duration_s": round(self.total_silence_duration_s, 3),
}
def log_summary(self) -> None:
"""Emit a structured log line summarising the session."""
self.total_processing_time_s = sum(self.transcription_durations)
d = self.to_dict()
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
logger.info(f"SESSION_METRICS {d}")

View File

@@ -147,8 +147,8 @@ def parse_args():
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral"],
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for Voxtral streaming via HuggingFace Transformers (CUDA/CPU/MPS).",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx"],
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon.",
)
parser.add_argument(
"--no-vac",

View File

@@ -0,0 +1,6 @@
"""Pure-MLX Voxtral Realtime backend for WhisperLiveKit."""
from .loader import load_voxtral_model
from .model import VoxtralMLXModel
__all__ = ["load_voxtral_model", "VoxtralMLXModel"]

View File

@@ -0,0 +1,282 @@
"""
Model weight loading for the MLX Voxtral Realtime backend.
Supports two on-disk formats:
1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load,
with optional quantisation metadata.
2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``):
requires weight renaming and conv-weight transposition.
The public entry point is :func:`load_voxtral_model` which returns the
model, tokenizer, and raw config dict.
"""
import json
import logging
import re
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from .model import VoxtralMLXModel
logger = logging.getLogger(__name__)
DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
# ---------------------------------------------------------------------------
# Downloading
# ---------------------------------------------------------------------------
_ALLOWED_PATTERNS = [
"consolidated.safetensors",
"model*.safetensors",
"model.safetensors.index.json",
"params.json",
"config.json",
"tekken.json",
]
def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path:
"""Download model files from HuggingFace Hub and return the local path."""
return Path(snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS))
# ---------------------------------------------------------------------------
# Weight name remapping (Mistral → our naming)
# ---------------------------------------------------------------------------
_NAME_RULES: list[tuple[str, str]] = [
# Encoder convolutions
(r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"),
(r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"),
# Encoder transformer blocks
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)",
r"encoder.blocks.\1.self_attn.q_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)",
r"encoder.blocks.\1.self_attn.k_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)",
r"encoder.blocks.\1.self_attn.v_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)",
r"encoder.blocks.\1.self_attn.out_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)",
r"encoder.blocks.\1.pre_attn_norm.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)",
r"encoder.blocks.\1.ffn.gate.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)",
r"encoder.blocks.\1.ffn.down.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)",
r"encoder.blocks.\1.ffn.up.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)",
r"encoder.blocks.\1.pre_ffn_norm.\2"),
(r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"),
# Adapter
(r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"),
(r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"),
# Decoder embedding
(r"tok_embeddings\.weight", r"decoder.token_embedding.weight"),
# Decoder blocks
(r"layers\.(\d+)\.attention\.wq\.weight",
r"decoder.blocks.\1.self_attn.q_proj.weight"),
(r"layers\.(\d+)\.attention\.wk\.weight",
r"decoder.blocks.\1.self_attn.k_proj.weight"),
(r"layers\.(\d+)\.attention\.wv\.weight",
r"decoder.blocks.\1.self_attn.v_proj.weight"),
(r"layers\.(\d+)\.attention\.wo\.weight",
r"decoder.blocks.\1.self_attn.out_proj.weight"),
(r"layers\.(\d+)\.attention_norm\.weight",
r"decoder.blocks.\1.pre_attn_norm.weight"),
(r"layers\.(\d+)\.feed_forward\.w1\.weight",
r"decoder.blocks.\1.ffn.gate.weight"),
(r"layers\.(\d+)\.feed_forward\.w2\.weight",
r"decoder.blocks.\1.ffn.down.weight"),
(r"layers\.(\d+)\.feed_forward\.w3\.weight",
r"decoder.blocks.\1.ffn.up.weight"),
(r"layers\.(\d+)\.ffn_norm\.weight",
r"decoder.blocks.\1.pre_ffn_norm.weight"),
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight",
r"decoder.blocks.\1.adaptive_scale.proj_in.weight"),
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight",
r"decoder.blocks.\1.adaptive_scale.proj_out.weight"),
# Decoder final norm
(r"norm\.weight", r"decoder.final_norm.weight"),
]
_PREFIX_STRIP = re.compile(
r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\."
)
def _translate_weight_name(name: str) -> str | None:
name = _PREFIX_STRIP.sub("", name)
for pattern, replacement in _NAME_RULES:
result, n = re.subn(f"^{pattern}$", replacement, name)
if n:
return result
return None
def _is_conv_weight(name: str) -> bool:
return ("conv1.weight" in name or "conv2.weight" in name) and "bias" not in name
# ---------------------------------------------------------------------------
# Converted-format weight remapping (voxmlx names → our names)
# ---------------------------------------------------------------------------
_CONVERTED_RULES: list[tuple[str, str]] = [
# Adapter
(r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"),
(r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"),
# Encoder transformer blocks
(r"encoder\.layers\.(\d+)\.attention\.(.*)", r"encoder.blocks.\1.self_attn.\2"),
(r"encoder\.layers\.(\d+)\.attn_norm\.(.*)", r"encoder.blocks.\1.pre_attn_norm.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"encoder.blocks.\1.ffn.gate.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"encoder.blocks.\1.ffn.down.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"encoder.blocks.\1.ffn.up.\2"),
(r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)", r"encoder.blocks.\1.pre_ffn_norm.\2"),
(r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"),
# Decoder embedding
(r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"),
# Decoder blocks
(r"language_model\.layers\.(\d+)\.attention\.(.*)", r"decoder.blocks.\1.self_attn.\2"),
(r"language_model\.layers\.(\d+)\.attn_norm\.(.*)", r"decoder.blocks.\1.pre_attn_norm.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"decoder.blocks.\1.ffn.gate.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"decoder.blocks.\1.ffn.down.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"decoder.blocks.\1.ffn.up.\2"),
(r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)", r"decoder.blocks.\1.pre_ffn_norm.\2"),
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)",
r"decoder.blocks.\1.adaptive_scale.proj_in.\2"),
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)",
r"decoder.blocks.\1.adaptive_scale.proj_out.\2"),
(r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"),
]
# Also remap o_proj → out_proj in both encoder and decoder
_POST_RENAME = [
(r"\.o_proj\.", r".out_proj."),
]
def _remap_converted_name(name: str) -> str:
"""Translate a converted-format weight name to our naming convention."""
for pattern, replacement in _CONVERTED_RULES:
result, n = re.subn(f"^{pattern}$", replacement, name)
if n:
name = result
break
for pattern, replacement in _POST_RENAME:
name = re.sub(pattern, replacement, name)
return name
# ---------------------------------------------------------------------------
# Loading strategies
# ---------------------------------------------------------------------------
def _has_converted_layout(path: Path) -> bool:
return (path / "config.json").exists() and not (path / "consolidated.safetensors").exists()
def _load_converted_weights(path: Path):
with open(path / "config.json") as f:
config = json.load(f)
model = VoxtralMLXModel(config)
quant = config.get("quantization")
if quant is not None:
gs = quant["group_size"]
nn.quantize(
model,
group_size=gs,
bits=quant["bits"],
class_predicate=lambda _p, m: (
hasattr(m, "to_quantized") and m.weight.shape[-1] % gs == 0
),
)
index_file = path / "model.safetensors.index.json"
if index_file.exists():
with open(index_file) as f:
shard_map = json.load(f)
shard_files = sorted(set(shard_map["weight_map"].values()))
weights = {}
for sf in shard_files:
weights.update(mx.load(str(path / sf)))
else:
weights = mx.load(str(path / "model.safetensors"))
remapped = {_remap_converted_name(k): v for k, v in weights.items()}
model.load_weights(list(remapped.items()))
mx.eval(model.parameters())
return model, config
def _load_original_weights(path: Path):
with open(path / "params.json") as f:
config = json.load(f)
model = VoxtralMLXModel(config)
raw = mx.load(str(path / "consolidated.safetensors"))
mapped: dict[str, mx.array] = {}
skipped: list[str] = []
for name, tensor in raw.items():
if name == "output.weight":
continue
new_name = _translate_weight_name(name)
if new_name is None:
skipped.append(name)
continue
# Conv weights: PyTorch [C_out, C_in, K] → MLX [C_out, K, C_in]
if _is_conv_weight(new_name):
tensor = mx.swapaxes(tensor, 1, 2)
mapped[new_name] = tensor
if skipped:
logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5])
model.load_weights(list(mapped.items()))
mx.eval(model.parameters())
return model, config
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
def _load_tokenizer(model_dir: Path):
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
return Tekkenizer.from_file(str(model_dir / "tekken.json"))
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID):
"""Load a Voxtral Realtime model and its tokenizer.
Args:
path_or_id: Local directory path **or** a HuggingFace model ID.
Returns:
``(model, tokenizer, config)``
"""
p = Path(path_or_id)
if not p.exists():
p = download_weights(path_or_id)
if _has_converted_layout(p):
model, config = _load_converted_weights(p)
else:
model, config = _load_original_weights(p)
tokenizer = _load_tokenizer(p)
logger.info("Voxtral MLX model loaded from %s", p)
return model, tokenizer, config

View File

@@ -0,0 +1,534 @@
"""
Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model.
Architecture:
audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits
with DelayEmbedding providing time-conditioning to the decoder.
The model supports both batch inference (full audio) and incremental streaming
(one chunk at a time with cached encoder/decoder state).
"""
import math
import mlx.core as mx
import mlx.nn as nn
# ---------------------------------------------------------------------------
# KV Cache
# ---------------------------------------------------------------------------
class SlidingKVCache:
"""Bounded key-value cache with rotating buffer for sliding-window attention.
Uses in-place writes for single-token autoregressive steps and
concatenation for multi-token prefills. Pre-allocates in blocks of
``alloc_step`` entries to reduce repeated allocation.
"""
alloc_step = 256
def __init__(self, capacity: int):
self.capacity = capacity
self.keys = None
self.values = None
self._offset = 0
self._write_idx = 0
@property
def offset(self) -> int:
return self._offset
# -- helpers --
def _reorder(self, buf):
"""Return *buf* in temporal order (unwrap the circular buffer)."""
if self._write_idx == buf.shape[2]:
return buf
if self._write_idx < self._offset:
return mx.concatenate(
[buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
axis=2,
)
return buf[..., : self._write_idx, :]
def _drop_oldest(self, buf, n_drop, tail=None):
parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
if tail is not None:
parts.append(tail)
return mx.concatenate(parts, axis=2)
# -- update strategies --
def _append_concat(self, k, v):
"""Multi-token update via concatenation (used during prefill)."""
if self.keys is None:
self.keys, self.values = k, v
else:
self.keys = self._reorder(self.keys)
self.values = self._reorder(self.values)
self._write_idx = self.keys.shape[2]
overflow = self._write_idx - self.capacity + 1
self.keys = self._drop_oldest(self.keys, overflow, k)
self.values = self._drop_oldest(self.values, overflow, v)
self._offset += k.shape[2]
self._write_idx = self.keys.shape[2]
return self.keys, self.values
def _write_inplace(self, k, v):
"""Single-token update via in-place write (autoregressive step)."""
B, n_heads, S, dim_k = k.shape
dim_v = v.shape[3]
prev = self._offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
):
n_new = min(self.alloc_step, self.capacity - prev)
fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
self.values = mx.concatenate([self.values, fresh_v], axis=2)
else:
self.keys, self.values = fresh_k, fresh_v
self._write_idx = prev
overflow = self.keys.shape[2] - self.capacity
if overflow > 0:
self.keys = self._drop_oldest(self.keys, overflow)
self.values = self._drop_oldest(self.values, overflow)
self._write_idx = self.capacity
if self._write_idx == self.capacity:
self._write_idx = 0
self.keys[..., self._write_idx : self._write_idx + S, :] = k
self.values[..., self._write_idx : self._write_idx + S, :] = v
self._offset += S
self._write_idx += S
if self._offset < self.capacity:
return (
self.keys[..., : self._offset, :],
self.values[..., : self._offset, :],
)
return self.keys, self.values
# -- public API --
def update_and_fetch(self, k, v):
if k.shape[2] == 1:
return self._write_inplace(k, v)
return self._append_concat(k, v)
# ---------------------------------------------------------------------------
# Encoder components
# ---------------------------------------------------------------------------
class CausalConv(nn.Module):
"""1-D causal convolution (left-padded so no future leakage)."""
def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
super().__init__()
self.stride = stride
self.kernel = kernel
self.left_pad = kernel - stride
self.weight = mx.zeros((channels_out, kernel, channels_in))
self.bias = mx.zeros((channels_out,))
def __call__(self, x: mx.array) -> mx.array:
if self.left_pad > 0:
x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
class _EncoderSelfAttention(nn.Module):
def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope_theta = rope_theta
def __call__(self, x, mask, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _EncoderFFN(nn.Module):
"""SwiGLU feed-forward for encoder layers."""
def __init__(self, dim: int, hidden: int):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=True)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class _EncoderBlock(nn.Module):
def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _EncoderFFN(dim, hidden)
def __call__(self, x, mask, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
x = x + self.ffn(self.pre_ffn_norm(x))
return x
class StreamingEncoder(nn.Module):
"""Causal Whisper-style encoder with two causal convolutions followed by
a stack of transformer blocks. Supports both full-sequence and
incremental (streaming) forward passes."""
def __init__(
self,
mel_channels: int = 128,
dim: int = 1280,
n_layers: int = 32,
n_heads: int = 32,
head_dim: int = 64,
hidden_dim: int = 5120,
rope_theta: float = 1e6,
sliding_window: int = 750,
):
super().__init__()
self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
self.blocks = [
_EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
self.sliding_window = sliding_window
# -- full-sequence --
def _apply_convs(self, mel: mx.array) -> mx.array:
x = mel.T[None, :, :] # [1, T, mel_channels]
x = nn.gelu(self.conv1(x))
x = nn.gelu(self.conv2(x))
return x
def forward(self, mel: mx.array) -> mx.array:
x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
for blk in self.blocks:
x = blk(x, mask="causal")
return self.final_norm(x)
# -- incremental (streaming) --
def forward_conv_incremental(self, x_in, tail1, tail2):
"""Process new mel frames through the two causal convs using cached tails.
Args:
x_in: [1, N, mel_channels]
tail1: [1, pad1, mel_channels] or None (first call)
tail2: [1, pad2, dim] or None (first call)
Returns:
(out, new_tail1, new_tail2)
"""
# Conv1 (kernel=3, stride=1 → left_pad=2)
if tail1 is not None:
c1_in = mx.concatenate([tail1, x_in], axis=1)
else:
c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
new_tail1 = x_in[:, -self.conv1.left_pad :, :]
c1_out = nn.gelu(
mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
)
# Conv2 (kernel=3, stride=2 → left_pad=1)
if tail2 is not None:
c2_in = mx.concatenate([tail2, c1_out], axis=1)
else:
c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
c2_out = nn.gelu(
mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
)
return c2_out, new_tail1, new_tail2
def forward_transformer_incremental(self, x, cache_list):
"""Run transformer blocks with per-layer KV caches."""
for i, blk in enumerate(self.blocks):
x = blk(x, mask="causal", cache=cache_list[i])
return self.final_norm(x)
# ---------------------------------------------------------------------------
# Decoder components
# ---------------------------------------------------------------------------
class _DecoderAttention(nn.Module):
"""Grouped-query attention for the text decoder."""
def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope_theta = rope_theta
def __call__(self, x, mask=None, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _DecoderFFN(nn.Module):
"""SwiGLU feed-forward for decoder layers."""
def __init__(self, dim, hidden):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=False)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class AdaptiveScaling(nn.Module):
"""Small MLP that produces a multiplicative scale from the delay embedding,
used to condition the FFN on the streaming delay."""
def __init__(self, dim, bottleneck):
super().__init__()
self.proj_in = nn.Linear(dim, bottleneck, bias=False)
self.proj_out = nn.Linear(bottleneck, dim, bias=False)
def __call__(self, cond):
return self.proj_out(nn.gelu(self.proj_in(cond)))
class _DecoderBlock(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _DecoderFFN(dim, hidden)
def __call__(self, x, delay_cond, mask=None, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond))
x = x + self.ffn(scaled)
return x
class TextDecoder(nn.Module):
"""Mistral-style causal language model with adaptive time-conditioning."""
def __init__(
self,
dim: int = 3072,
n_layers: int = 26,
n_heads: int = 32,
n_kv_heads: int = 8,
head_dim: int = 128,
hidden_dim: int = 9216,
vocab_size: int = 131072,
rope_theta: float = 1e6,
cond_dim: int = 32,
):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, dim)
self.blocks = [
_DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
def embed(self, token_ids: mx.array) -> mx.array:
return self.token_embedding(token_ids)
def __call__(self, x, delay_cond, mask=None, cache=None):
delay_cond = delay_cond.astype(x.dtype)
for i, blk in enumerate(self.blocks):
blk_cache = cache[i] if cache is not None else None
x = blk(x, delay_cond, mask, blk_cache)
x = self.final_norm(x)
return self.token_embedding.as_linear(x)
# ---------------------------------------------------------------------------
# Adapter & embeddings
# ---------------------------------------------------------------------------
class EncoderToDecoderAdapter(nn.Module):
"""Two-layer projection from encoder space to decoder space."""
def __init__(self, enc_dim: int, dec_dim: int):
super().__init__()
self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)
def __call__(self, x):
return self.linear2(nn.gelu(self.linear1(x)))
class DelayEmbedding(nn.Module):
"""Sinusoidal embedding that encodes the streaming delay as a conditioning
vector for the decoder's adaptive scaling."""
def __init__(self, dim: int = 3072, theta: float = 10000.0):
super().__init__()
self.dim = dim
half = dim // 2
freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
self._freqs = freqs
def __call__(self, delay: mx.array) -> mx.array:
t = delay.reshape(-1, 1).astype(mx.float32)
angles = t * self._freqs
return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
class VoxtralMLXModel(nn.Module):
"""Top-level Voxtral Realtime model wiring encoder, adapter, and decoder."""
def __init__(self, config: dict):
super().__init__()
enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
audio_cfg = enc_cfg["audio_encoding_args"]
ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]
self.encoder = StreamingEncoder(
mel_channels=audio_cfg["num_mel_bins"],
dim=enc_cfg["dim"],
n_layers=enc_cfg["n_layers"],
n_heads=enc_cfg["n_heads"],
head_dim=enc_cfg["head_dim"],
hidden_dim=enc_cfg["hidden_dim"],
rope_theta=enc_cfg["rope_theta"],
sliding_window=enc_cfg["sliding_window"],
)
adapter_input_dim = enc_cfg["dim"] * ds_factor
decoder_dim = config["dim"]
cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)
self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)
self.decoder = TextDecoder(
dim=decoder_dim,
n_layers=config["n_layers"],
n_heads=config["n_heads"],
n_kv_heads=config["n_kv_heads"],
head_dim=config["head_dim"],
hidden_dim=config["hidden_dim"],
vocab_size=config["vocab_size"],
rope_theta=config["rope_theta"],
cond_dim=cond_bottleneck,
)
self.delay_embedding = DelayEmbedding(dim=decoder_dim)
self.ds_factor = ds_factor
# -- batch encode --
def encode(self, mel: mx.array) -> mx.array:
T = mel.shape[1]
if T % 2 != 0:
mel = mel[:, 1:]
h = self.encoder.forward(mel) # [1, T/2, enc_dim]
h = h[0]
n = h.shape[0]
trim = n % self.ds_factor
if trim:
h = h[trim:]
n = h.shape[0]
h = h.reshape(n // self.ds_factor, -1)
return self.adapter(h)
# -- incremental encode --
def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
"""Incrementally encode new mel frames.
Returns:
(audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
"""
x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)
x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)
if enc_cache is None:
enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]
x = self.encoder.forward_transformer_incremental(x, enc_cache)
x = x[0] # [N, enc_dim]
if ds_remainder is not None:
x = mx.concatenate([ds_remainder, x])
n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
if n_full == 0:
return None, conv_tail1, conv_tail2, enc_cache, x
leftover = x[n_full:] if x.shape[0] > n_full else None
x = x[:n_full].reshape(n_full // self.ds_factor, -1)
return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover
# -- decode --
def decode(self, embeddings, delay_cond, mask=None, cache=None):
return self.decoder(embeddings, delay_cond, mask, cache)

View File

@@ -0,0 +1,202 @@
"""
Mel spectrogram computation for Voxtral Realtime.
Provides both a full-audio function and an incremental streaming variant
that maintains overlap state between calls. The DFT is computed via
matrix multiplication in MLX — no external FFT dependency required.
"""
import math
import mlx.core as mx
import numpy as np
# Audio / mel constants matching the Voxtral Realtime model expectations.
SAMPLE_RATE = 16_000
WINDOW_SIZE = 400 # n_fft
HOP = 160
MEL_BANDS = 128
MEL_MAX = 1.5 # global log-mel normalisation ceiling
# Each output audio token spans: hop * conv_stride(2) * downsample_factor(4)
SAMPLES_PER_TOKEN = HOP * 2 * 4 # = 1280 samples = 80 ms
# Padding tokens used by the model prompt structure.
LEFT_PAD_TOKENS = 32
RIGHT_PAD_TOKENS = 17
# ---------------------------------------------------------------------------
# Slaney mel filterbank
# ---------------------------------------------------------------------------
def _build_slaney_filterbank(
sr: int = SAMPLE_RATE,
n_fft: int = WINDOW_SIZE,
n_mels: int = MEL_BANDS,
lo_hz: float = 0.0,
hi_hz: float = 8000.0,
) -> np.ndarray:
"""Compute a Slaney-normalised triangular mel filterbank.
Returns an array of shape ``[n_mels, n_fft//2 + 1]``.
"""
def _hz2mel(f):
threshold = 1000.0
base_mel = 15.0
log_coeff = 27.0 / np.log(6.4)
mel = 3.0 * f / 200.0
if isinstance(f, np.ndarray):
above = f >= threshold
mel[above] = base_mel + np.log(f[above] / threshold) * log_coeff
elif f >= threshold:
mel = base_mel + np.log(f / threshold) * log_coeff
return mel
def _mel2hz(m):
threshold = 1000.0
base_mel = 15.0
log_coeff = np.log(6.4) / 27.0
hz = 200.0 * m / 3.0
above = m >= base_mel
hz[above] = threshold * np.exp(log_coeff * (m[above] - base_mel))
return hz
n_bins = n_fft // 2 + 1
fft_hz = np.linspace(0, sr / 2, n_bins)
mel_lo, mel_hi = _hz2mel(lo_hz), _hz2mel(hi_hz)
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
hz_pts = _mel2hz(mel_pts)
diffs = np.diff(hz_pts)
slopes = np.expand_dims(hz_pts, 0) - np.expand_dims(fft_hz, 1)
rising = -slopes[:, :-2] / diffs[:-1]
falling = slopes[:, 2:] / diffs[1:]
fb = np.maximum(0.0, np.minimum(rising, falling))
# Slaney area normalisation
widths = 2.0 / (hz_pts[2 : n_mels + 2] - hz_pts[:n_mels])
fb *= np.expand_dims(widths, 0)
return fb.T.astype(np.float32)
_CACHED_FILTERS: mx.array | None = None
def _mel_filters() -> mx.array:
global _CACHED_FILTERS
if _CACHED_FILTERS is None:
_CACHED_FILTERS = mx.array(_build_slaney_filterbank())
return _CACHED_FILTERS
# ---------------------------------------------------------------------------
# DFT helpers
# ---------------------------------------------------------------------------
def _hann_window() -> mx.array:
return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
def _dft_matrices():
"""Pre-compute the real / imaginary DFT basis matrices."""
n_bins = WINDOW_SIZE // 2 + 1
k = mx.arange(n_bins, dtype=mx.float32)[:, None]
n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
return mx.cos(phase), mx.sin(phase)
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
"""Frame *audio* using the Hann window and compute power spectrogram."""
n_bins = WINDOW_SIZE // 2 + 1
n_frames = 1 + (audio.shape[0] - WINDOW_SIZE) // HOP
if n_frames <= 0:
return mx.zeros((0, n_bins))
offsets = (mx.arange(n_frames) * HOP)[:, None]
indices = offsets + mx.arange(WINDOW_SIZE)[None, :]
windowed = audio[indices] * window[None, :]
dft_re, dft_im = _dft_matrices()
real_part = windowed @ dft_re.T
imag_part = windowed @ dft_im.T
return real_part ** 2 + imag_part ** 2
def _apply_mel_and_log(power: mx.array) -> mx.array:
"""Convert a power spectrogram to log-mel and normalise."""
mel = power @ _mel_filters().T
log_mel = mx.log10(mx.maximum(mel, 1e-10))
log_mel = mx.maximum(log_mel, MEL_MAX - 8.0)
return (log_mel + 4.0) / 4.0
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def compute_mel(audio: np.ndarray) -> mx.array:
"""Compute log-mel spectrogram for a complete audio signal.
Args:
audio: 1-D float32 numpy array at ``SAMPLE_RATE``.
Returns:
``[MEL_BANDS, T]`` MLX array.
"""
x = mx.array(audio)
pad = WINDOW_SIZE // 2
x = mx.pad(x, [(pad, pad)])
window = _hann_window()
power = _stft_frames(x, window)
# Drop last frame to match reference STFT behaviour
power = power[:-1]
return _apply_mel_and_log(power).T
def compute_mel_streaming(
chunk: np.ndarray,
overlap: np.ndarray | None,
) -> tuple[mx.array, np.ndarray]:
"""Incrementally compute log-mel for a new audio chunk.
Args:
chunk: New audio samples (float32 numpy).
overlap: The last ``WINDOW_SIZE - HOP`` = 240 samples from the
previous call, or *None* on the first call (uses zero-padding).
Returns:
``(mel, new_overlap)`` where *mel* is ``[MEL_BANDS, N]`` and
*new_overlap* is the 240-sample tail for the next call.
"""
tail_len = WINDOW_SIZE - HOP # 240
if overlap is not None:
combined = np.concatenate([overlap, chunk])
else:
combined = np.concatenate([np.zeros(WINDOW_SIZE // 2, dtype=np.float32), chunk])
new_overlap = combined[-tail_len:].copy()
x = mx.array(combined)
window = _hann_window()
power = _stft_frames(x, window)
if power.shape[0] == 0:
return mx.zeros((MEL_BANDS, 0)), new_overlap
return _apply_mel_and_log(power).T, new_overlap
def pad_audio(
audio: np.ndarray,
n_left: int = LEFT_PAD_TOKENS,
n_right: int = RIGHT_PAD_TOKENS,
) -> np.ndarray:
"""Pad audio with silence for batch (non-streaming) inference."""
left = n_left * SAMPLES_PER_TOKEN
align = (SAMPLES_PER_TOKEN - (len(audio) % SAMPLES_PER_TOKEN)) % SAMPLES_PER_TOKEN
right = align + n_right * SAMPLES_PER_TOKEN
return np.pad(audio, (left, right))

View File

@@ -0,0 +1,521 @@
"""
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
(streaming processor) that plug into WhisperLiveKit's audio processing
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
Unlike the HuggingFace backend, this runs the full inference loop in-process
(no background thread / queue) — MLX operations on Apple Silicon are fast
enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
"""
import logging
import sys
import time
from typing import List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
from whisperlivekit.voxtral_mlx.spectrogram import (
SAMPLES_PER_TOKEN,
LEFT_PAD_TOKENS,
RIGHT_PAD_TOKENS,
compute_mel_streaming,
)
logger = logging.getLogger(__name__)
# Decoder sliding-window size (matches the model's training configuration).
_DECODER_WINDOW = 8192
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
"""Build the prompt token sequence and return ``(token_ids, n_delay)``."""
pad_id = tokenizer.get_special_token("[STREAMING_PAD]")
ids = [tokenizer.bos_id] + [pad_id] * (n_left_pad + n_delay)
return ids, n_delay
# ---------------------------------------------------------------------------
# Model holder
# ---------------------------------------------------------------------------
class VoxtralMLXASR:
"""Lightweight model holder — loads the MLX Voxtral model once and keeps
it alive for the lifetime of the server."""
sep = " "
SAMPLING_RATE = 16_000
def __init__(self, logfile=sys.stderr, **kwargs):
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = DEFAULT_MODEL_ID
t0 = time.time()
logger.info("Loading Voxtral MLX model '%s' ...", model_path)
self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
logger.info("Voxtral MLX model loaded in %.2fs", time.time() - t0)
self.backend_choice = "voxtral-mlx"
def transcribe(self, audio):
pass # all work happens in the online processor
# ---------------------------------------------------------------------------
# Online processor
# ---------------------------------------------------------------------------
class VoxtralMLXOnlineProcessor:
"""Streaming processor that incrementally encodes audio and decodes text
using the MLX Voxtral model.
Lifecycle (called by ``AudioProcessor.transcription_processor``):
insert_audio_chunk(pcm, time) → process_iter() → get_buffer()
... repeat ...
start_silence() / end_silence()
finish()
"""
SAMPLING_RATE = 16_000
def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer: list = []
self.audio_buffer = np.array([], dtype=np.float32)
self._model = asr.model
self._tokenizer = asr.tokenizer
# Pre-compute prompt tokens and delay conditioning (constant across utterances).
self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
self._prefix_len = len(self._prompt_ids)
self._delay_cond = self._model.delay_embedding(
mx.array([self._n_delay], dtype=mx.float32)
)
mx.eval(self._delay_cond)
self._prompt_embeds = self._model.decoder.embed(
mx.array([self._prompt_ids])
)[0] # [prefix_len, dim]
mx.eval(self._prompt_embeds)
self._eos_id = self._tokenizer.eos_id
self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
# The streaming model has an inherent delay: text for audio at position P
# is generated at decoder position P + n_delay. Compensate timestamps.
self._delay_secs = self._n_delay * self._secs_per_token
self._reset_state()
# -- state management --
def _reset_state(self):
"""Reset all incremental state for a fresh utterance."""
# Audio accumulation
self._pending = np.zeros(0, dtype=np.float32)
# Mel overlap
self._mel_overlap: np.ndarray | None = None
# Encoder incremental state
self._conv_tail1 = None
self._conv_tail2 = None
self._enc_cache = None
self._ds_remainder = None
# Audio embeddings not yet decoded
self._audio_embeds: mx.array | None = None
# Decoder state
self._dec_cache: list[SlidingKVCache] | None = None
self._last_token: mx.array | None = None
# Bookkeeping
self._samples_encoded = 0
self._positions_decoded = 0
self._prefilled = False
self._first_chunk = True
# Text state
self._full_text = ""
self._n_text_tokens = 0
self._n_committed_words = 0
self._time_offset = 0.0
# Per-word audio position tracking: decoder position (relative to prefix)
# where each word in _full_text started and ended
self._word_audio_starts: list[int] = [] # audio pos where word i started
self._word_audio_ends: list[int] = [] # audio pos where word i last produced a token
self._current_word_pos: Optional[int] = None # audio pos of current (incomplete) word's first token
# -- audio ingestion --
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending = np.append(self._pending, audio)
self.audio_buffer = self._pending
# -- core processing --
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._step(is_last)
except Exception as e:
logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
return [], self.end
def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
# 1. Encode any new audio
self._encode_pending()
if self._audio_embeds is None:
return [], self.end
# 2. Compute how many positions we can safely decode
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
n_available = self._audio_embeds.shape[0]
n_decodable = min(n_available, total_safe - self._positions_decoded)
if n_decodable <= 0:
return [], self.end
# 3. Prefill if needed
if not self._prefilled:
if self._positions_decoded + n_available < self._prefix_len:
return [], self.end
self._do_prefill()
# Re-check after consuming prefix embeddings
n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
n_decodable = min(n_available, total_safe - self._positions_decoded)
if n_decodable <= 0 or self._audio_embeds is None:
return [], self.end
# 4. Decode available positions
hit_eos = self._decode_positions(n_decodable)
if hit_eos:
# Flush words, reset for next utterance
words = self._flush_all_words()
logger.debug(
"[voxtral-mlx] EOS hit during stream: flushed %d words, "
"samples_encoded=%d (%.2fs), text='%s'",
len(words), self._samples_encoded,
self._samples_encoded / self.SAMPLING_RATE,
self._full_text[-60:] if self._full_text else "",
)
saved_offset = self._time_offset
self._reset_state()
self._time_offset = saved_offset
return words, self.end
# 5. Extract committed words (all but the last, which may still grow)
return self._extract_committed_words(), self.end
def _encode_pending(self):
"""Feed pending audio through the incremental encoder."""
available = len(self._pending)
if available < SAMPLES_PER_TOKEN:
return
if self._first_chunk:
# First chunk: prepend silence for left-padding
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
chunk = np.concatenate([left_pad, self._pending[:n_take]])
self._pending = self._pending[n_take:]
self._samples_encoded += n_take
self._first_chunk = False
else:
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
chunk = self._pending[:n_take]
self._pending = self._pending[n_take:]
self._samples_encoded += n_take
mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
self._model.encode_incremental(
mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
)
)
if embeds is not None:
mx.eval(embeds)
if self._audio_embeds is not None:
self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
else:
self._audio_embeds = embeds
self.audio_buffer = self._pending
def _do_prefill(self):
"""Run the decoder prefill pass over the prompt + first audio embeddings."""
n_dec_layers = len(self._model.decoder.blocks)
self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]
prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
prefix_embeds = prefix_embeds[None, :, :] # [1, prefix_len, dim]
logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])
self._last_token = self._sample(logits)
mx.async_eval(self._last_token)
# Remove consumed prefix embeddings
self._audio_embeds = self._audio_embeds[self._prefix_len :]
if self._audio_embeds.shape[0] == 0:
self._audio_embeds = None
self._positions_decoded = self._prefix_len
self._prefilled = True
def _decode_positions(self, n: int) -> bool:
"""Autoregressively decode *n* positions. Returns True on EOS."""
base_pos = self._positions_decoded # absolute position before this batch
for i in range(n):
tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
next_tok = self._sample(logits)
mx.async_eval(next_tok)
token_id = self._last_token.item()
if token_id == self._eos_id:
# Close the current word if one is being built
if self._current_word_pos is not None:
self._word_audio_ends.append(base_pos + i - self._prefix_len)
self._current_word_pos = None
self._trim_embeds(i)
self._positions_decoded += i
return True
text = self._tokenizer.decode(
[token_id], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
audio_pos = base_pos + i - self._prefix_len
# Detect word boundary: new word starts with space or is the very first text
if text.lstrip() != text or not self._full_text:
# Close previous word if exists
if self._current_word_pos is not None:
self._word_audio_ends.append(audio_pos)
# Start new word
self._word_audio_starts.append(audio_pos)
self._current_word_pos = audio_pos
elif self._current_word_pos is None:
# First token of first word (no leading space)
self._word_audio_starts.append(audio_pos)
self._current_word_pos = audio_pos
self._full_text += text
self._n_text_tokens += 1
if i > 0 and i % 256 == 0:
mx.clear_cache()
self._last_token = next_tok
self._positions_decoded += n
self._trim_embeds(n)
return False
def _trim_embeds(self, n_consumed: int):
if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
self._audio_embeds = self._audio_embeds[n_consumed:]
else:
self._audio_embeds = None
def _sample(self, logits: mx.array) -> mx.array:
return mx.argmax(logits[0, -1:], axis=-1).squeeze()
# -- word extraction --
def _audio_pos_to_time(self, pos: int) -> float:
"""Convert an audio position (relative to prefix end) to seconds."""
return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)
def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
"""Compute (start, end) time for a word using tracked word positions."""
starts = self._word_audio_starts
ends = self._word_audio_ends
if not starts:
return self._time_offset, self._time_offset
# Get start position for this word
if word_idx < len(starts):
t0 = self._audio_pos_to_time(starts[word_idx])
else:
# Fallback: estimate from last known position
last_pos = ends[-1] if ends else starts[-1]
t0 = self._audio_pos_to_time(last_pos + 1)
# Get end position: use the start of the next word, or the end of this word
if word_idx + 1 < len(starts):
t1 = self._audio_pos_to_time(starts[word_idx + 1])
elif word_idx < len(ends):
t1 = self._audio_pos_to_time(ends[word_idx] + 1)
else:
# Last word, still being built: use last known position + 1 token
last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
t1 = self._audio_pos_to_time(last_pos + 1)
return t0, t1
def _extract_committed_words(self) -> List[ASRToken]:
"""Return complete words (all except the last which may still grow)."""
if not self._full_text:
return []
words = self._full_text.split()
tokens: List[ASRToken] = []
n_total = max(len(words), 1)
while len(words) > self._n_committed_words + 1:
w = words[self._n_committed_words]
idx = self._n_committed_words
t0, t1 = self._word_time_range(idx, n_total)
label = w if idx == 0 else " " + w
tokens.append(ASRToken(start=t0, end=t1, text=label))
self._n_committed_words += 1
return tokens
def _flush_all_words(self) -> List[ASRToken]:
"""Flush every word including the last partial one."""
if not self._full_text:
return []
words = self._full_text.split()
tokens: List[ASRToken] = []
n_total = max(len(words), 1)
while self._n_committed_words < len(words):
w = words[self._n_committed_words]
idx = self._n_committed_words
t0, t1 = self._word_time_range(idx, n_total)
label = w if idx == 0 else " " + w
tokens.append(ASRToken(start=t0, end=t1, text=label))
self._n_committed_words += 1
return tokens
# -- interface methods --
def get_buffer(self) -> Transcript:
if not self._full_text:
return Transcript(start=None, end=None, text="")
words = self._full_text.split()
remaining = words[self._n_committed_words :]
if remaining:
return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
words = self._flush_all_words()
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug(
"[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
"samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
len(self._pending),
self._audio_embeds.shape if self._audio_embeds is not None else None,
self._samples_encoded,
self._positions_decoded,
self._prefilled,
self._full_text[-80:] if self._full_text else "",
)
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
remainder = len(self._pending) % SAMPLES_PER_TOKEN
if remainder > 0:
align_pad = SAMPLES_PER_TOKEN - remainder
else:
align_pad = 0
# Add alignment + right-padding silence
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
if total_pad > 0:
self._pending = np.append(
self._pending, np.zeros(total_pad, dtype=np.float32)
)
# Encode remaining audio (including right-padding)
self._encode_pending()
logger.debug(
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
self._audio_embeds.shape if self._audio_embeds is not None else None,
len(self._pending),
)
hit_eos = False
# Decode everything that's left from right-padding
if self._audio_embeds is not None and self._prefilled:
hit_eos = self._decode_positions(self._audio_embeds.shape[0])
logger.debug(
"[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
hit_eos, self._full_text[-80:] if self._full_text else "",
)
# Flush last token if it wasn't EOS
if self._last_token is not None:
tid = self._last_token.item()
if tid != self._eos_id:
text = self._tokenizer.decode(
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
last_pos = self._positions_decoded - self._prefix_len
# Check if this starts a new word
if text.lstrip() != text or not self._full_text:
if self._current_word_pos is not None:
self._word_audio_ends.append(last_pos)
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
elif self._current_word_pos is None:
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
self._full_text += text
self._n_text_tokens += 1
# Close the last word if still open
if self._current_word_pos is not None:
last_pos = self._positions_decoded - self._prefix_len
self._word_audio_ends.append(last_pos)
self._current_word_pos = None
words = self._flush_all_words()
logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
return words, self.end

View File

@@ -296,10 +296,15 @@ class Tokenizer:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
try:
replacement_char_index = decoded.index(replacement_char)
replacement_char_index += unicode_offset
except ValueError:
replacement_char_index = None
if replacement_char_index is None or (
replacement_char_index < len(decoded_full)
and decoded_full[replacement_char_index] == replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)