mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
78 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bc0937c46 | ||
|
|
929cf7a26b | ||
|
|
abfaf06203 | ||
|
|
d1fe932241 | ||
|
|
c112ceffb6 | ||
|
|
4917406e06 | ||
|
|
b63f54e838 | ||
|
|
c56a53fbf4 | ||
|
|
66e58624b9 | ||
|
|
9366e067f9 | ||
|
|
866c25670c | ||
|
|
2553ef283e | ||
|
|
73e7fafc48 | ||
|
|
bbcebcb1fe | ||
|
|
4bb58dc7aa | ||
|
|
27ca028479 | ||
|
|
d24805cc18 | ||
|
|
994ce21365 | ||
|
|
132823dc09 | ||
|
|
d6d8c2635f | ||
|
|
8fedeb9fed | ||
|
|
b1fc23807a | ||
|
|
10c4e5f730 | ||
|
|
c76b2ef2c6 | ||
|
|
4b2377c243 | ||
|
|
a4da246ea5 | ||
|
|
9b2c3ee844 | ||
|
|
83d0fa3fac | ||
|
|
5a12c627b4 | ||
|
|
f5eee67b11 | ||
|
|
4a6868e3e1 | ||
|
|
3c15246fc0 | ||
|
|
d337248fda | ||
|
|
b8d9d7d289 | ||
|
|
4c7706e2cf | ||
|
|
7f3a3df620 | ||
|
|
e7e82f7c19 | ||
|
|
8c799fa4d1 | ||
|
|
8923337380 | ||
|
|
aded1649ae | ||
|
|
3b535e857a | ||
|
|
d649250b9a | ||
|
|
7735478286 | ||
|
|
b9e72d2b9a | ||
|
|
e5b01033af | ||
|
|
6ae545bcb1 | ||
|
|
04980d3f5e | ||
|
|
79a705c969 | ||
|
|
34e4abd455 | ||
|
|
d59ddbaeae | ||
|
|
4dd66e7766 | ||
|
|
3db5d81a20 | ||
|
|
b67ddea494 | ||
|
|
3192553e20 | ||
|
|
f379a243fe | ||
|
|
ec09898a9f | ||
|
|
befbae56c7 | ||
|
|
719e8b1a20 | ||
|
|
f1b47178d8 | ||
|
|
59db08e961 | ||
|
|
6fc20b9562 | ||
|
|
fac8659161 | ||
|
|
4d9332ce7d | ||
|
|
62444ce746 | ||
|
|
2431a6bf91 | ||
|
|
d1263e7228 | ||
|
|
30ddd522a4 | ||
|
|
635bace09e | ||
|
|
f1113e3eb0 | ||
|
|
cc5f819ce7 | ||
|
|
82cd24bb75 | ||
|
|
d45c397c6a | ||
|
|
45bf3f57d7 | ||
|
|
1d88ba9d69 | ||
|
|
c0965c6c31 | ||
|
|
34ddd2ac02 | ||
|
|
345d781e97 | ||
|
|
28cf831701 |
13
.dockerignore
Normal file
13
.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
.git
|
||||||
|
.github
|
||||||
|
.venv
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
.pytest_cache
|
||||||
|
.mypy_cache
|
||||||
|
.ruff_cache
|
||||||
|
.cache
|
||||||
|
.tmp
|
||||||
|
.secrets
|
||||||
|
dist
|
||||||
|
build
|
||||||
61
.github/workflows/publish-docker.yml
vendored
Normal file
61
.github/workflows/publish-docker.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
name: Publish Docker Images
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
tag:
|
||||||
|
description: "Image tag to publish (without image suffix)"
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
IMAGE_TAG: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.tag || github.ref_name }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- image_suffix: cpu-diarization-sortformer
|
||||||
|
dockerfile: Dockerfile.cpu
|
||||||
|
extras: cpu,diarization-sortformer
|
||||||
|
- image_suffix: cu129-diarization-sortformer
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
extras: cu129,diarization-sortformer
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set lowercase owner
|
||||||
|
id: owner
|
||||||
|
run: echo "value=${GITHUB_REPOSITORY_OWNER,,}" >> "${GITHUB_OUTPUT}"
|
||||||
|
|
||||||
|
- name: Login to GHCR
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Setup Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Build and push image
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./${{ matrix.dockerfile }}
|
||||||
|
push: true
|
||||||
|
build-args: |
|
||||||
|
EXTRAS=${{ matrix.extras }}
|
||||||
|
tags: |
|
||||||
|
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:${{ env.IMAGE_TAG }}-${{ matrix.image_suffix }}
|
||||||
|
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:latest-${{ matrix.image_suffix }}
|
||||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -119,9 +119,11 @@ run_*.sh
|
|||||||
*.pt
|
*.pt
|
||||||
|
|
||||||
# Debug & testing
|
# Debug & testing
|
||||||
test_*.py
|
/test_*.py
|
||||||
|
!test_backend_offline.py
|
||||||
launch.json
|
launch.json
|
||||||
.DS_Store
|
.DS_Store
|
||||||
test/*
|
/test/
|
||||||
|
!tests/
|
||||||
nllb-200-distilled-600M-ctranslate2/*
|
nllb-200-distilled-600M-ctranslate2/*
|
||||||
*.mp3
|
*.mp3
|
||||||
205
BENCHMARK.md
Normal file
205
BENCHMARK.md
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
# WhisperLiveKit Benchmark Report
|
||||||
|
|
||||||
|
Benchmark comparing all supported ASR backends, streaming policies, and model sizes on Apple Silicon.
|
||||||
|
All tests run through the full AudioProcessor pipeline (same code path as production WebSocket).
|
||||||
|
|
||||||
|
## Test Environment
|
||||||
|
|
||||||
|
| Property | Value |
|
||||||
|
|----------|-------|
|
||||||
|
| Hardware | Apple M4, 32 GB RAM |
|
||||||
|
| OS | macOS 25.3.0 (arm64) |
|
||||||
|
| Python | 3.13 |
|
||||||
|
| faster-whisper | 1.2.1 |
|
||||||
|
| mlx-whisper | installed (via mlx) |
|
||||||
|
| Voxtral MLX | native MLX backend |
|
||||||
|
| Voxtral (HF) | transformers-based |
|
||||||
|
| VAC (Silero VAD) | enabled unless noted |
|
||||||
|
| Chunk size | 100 ms |
|
||||||
|
| Pacing | no-realtime (as fast as possible) |
|
||||||
|
|
||||||
|
## Audio Test Files
|
||||||
|
|
||||||
|
| File | Duration | Language | Speakers | Description |
|
||||||
|
|------|----------|----------|----------|-------------|
|
||||||
|
| `00_00_07_english_1_speaker.wav` | 7.2 s | English | 1 | Short dictation with pauses |
|
||||||
|
| `00_00_16_french_1_speaker.wav` | 16.3 s | French | 1 | French speech with intentional silence gaps |
|
||||||
|
| `00_00_30_english_3_speakers.wav` | 30.0 s | English | 3 | Multi-speaker conversation |
|
||||||
|
|
||||||
|
Ground truth transcripts (`.transcript.json`) with per-word timestamps are hand-verified.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
### English -- Short (7.2 s, 1 speaker)
|
||||||
|
|
||||||
|
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|
||||||
|
|---------|--------|-------|-----|-----|---------------|
|
||||||
|
| faster-whisper | LocalAgreement | base | 0.20x | 21.1% | 0.080 s |
|
||||||
|
| faster-whisper | SimulStreaming | base | 0.14x | 0.0% | 0.239 s |
|
||||||
|
| faster-whisper | LocalAgreement | small | 0.59x | 21.1% | 0.089 s |
|
||||||
|
| faster-whisper | SimulStreaming | small | 0.39x | 0.0% | 0.221 s |
|
||||||
|
| mlx-whisper | LocalAgreement | base | 0.05x | 21.1% | 0.080 s |
|
||||||
|
| mlx-whisper | SimulStreaming | base | 0.14x | 10.5% | 0.245 s |
|
||||||
|
| mlx-whisper | LocalAgreement | small | 0.16x | 21.1% | 0.089 s |
|
||||||
|
| mlx-whisper | SimulStreaming | small | 0.20x | 10.5% | 0.226 s |
|
||||||
|
| voxtral-mlx | voxtral | 4B | 0.32x | 0.0% | 0.254 s |
|
||||||
|
| voxtral (HF) | voxtral | 4B | 1.29x | 0.0% | 1.876 s |
|
||||||
|
|
||||||
|
### English -- Multi-speaker (30.0 s, 3 speakers)
|
||||||
|
|
||||||
|
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|
||||||
|
|---------|--------|-------|-----|-----|---------------|
|
||||||
|
| faster-whisper | LocalAgreement | base | 0.24x | 44.7% | 0.235 s |
|
||||||
|
| faster-whisper | SimulStreaming | base | 0.10x | 5.3% | 0.398 s |
|
||||||
|
| faster-whisper | LocalAgreement | small | 0.59x | 25.0% | 0.226 s |
|
||||||
|
| faster-whisper | SimulStreaming | small | 0.26x | 5.3% | 0.387 s |
|
||||||
|
| mlx-whisper | LocalAgreement | base | 0.06x | 23.7% | 0.237 s |
|
||||||
|
| mlx-whisper | SimulStreaming | base | 0.11x | 5.3% | 0.395 s |
|
||||||
|
| mlx-whisper | LocalAgreement | small | 0.13x | 25.0% | 0.226 s |
|
||||||
|
| mlx-whisper | SimulStreaming | small | 0.20x | 5.3% | 0.394 s |
|
||||||
|
| voxtral-mlx | voxtral | 4B | 0.31x | 9.2% | 0.176 s |
|
||||||
|
| voxtral (HF) | voxtral | 4B | 1.00x | 32.9% | 1.034 s |
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="benchmark_chart.png" alt="Benchmark comparison on 30s English" width="800">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
### French (16.3 s, 1 speaker, `--language fr`)
|
||||||
|
|
||||||
|
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|
||||||
|
|---------|--------|-------|-----|-----|---------------|
|
||||||
|
| faster-whisper | LocalAgreement | base | 0.22x | 25.7% | 3.460 s |
|
||||||
|
| faster-whisper | SimulStreaming | base | 0.10x | 31.4% | 3.660 s |
|
||||||
|
| faster-whisper | LocalAgreement | small | 0.76x | 42.9% | 0.051 s |
|
||||||
|
| faster-whisper | SimulStreaming | small | 0.29x | 25.7% | 0.219 s |
|
||||||
|
| mlx-whisper | LocalAgreement | base | 0.09x | ~45%* | ~5.0 s* |
|
||||||
|
| mlx-whisper | SimulStreaming | base | 0.09x | 40.0% | 3.540 s |
|
||||||
|
| mlx-whisper | LocalAgreement | small | 0.14x | 25.7% | 0.083 s |
|
||||||
|
| mlx-whisper | SimulStreaming | small | 0.17x | 31.4% | 0.203 s |
|
||||||
|
| voxtral-mlx | voxtral | 4B | 0.18x | 37.1% | 3.422 s |
|
||||||
|
| voxtral (HF) | voxtral | 4B | 0.63x | 28.6% | 4.040 s |
|
||||||
|
|
||||||
|
\* mlx-whisper + LocalAgreement + base is unstable on this French file (WER fluctuates 34-1037% across runs due to hallucination loops). The `small` model does not have this problem.
|
||||||
|
|
||||||
|
**Timestamp note:** The base model produces very high timestamp MAE (3.4-3.7s) on this French file because it misaligns words around the silence gaps. The small model handles this much better (0.05-0.22s MAE). Voxtral also drifts on the silence gaps.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model Size Comparison (base vs small)
|
||||||
|
|
||||||
|
| | base | small | Observation |
|
||||||
|
|--|------|-------|-------------|
|
||||||
|
| **RTF** | 0.05-0.24x | 0.13-0.76x | small is 2-3x slower |
|
||||||
|
| **English WER (SS)** | 0-5.3% | 0-5.3% | No improvement: SimulStreaming already saturates on base |
|
||||||
|
| **English WER (LA)** | 21-44.7% | 21-25% | small reduces LA errors on longer audio |
|
||||||
|
| **French WER** | 25-40% | 25-43% | Mixed: depends on backend/policy combo |
|
||||||
|
| **French timestamps** | 3.4-5.0s MAE | 0.05-0.22s MAE | small is dramatically better for French timestamps |
|
||||||
|
|
||||||
|
In short: **base + SimulStreaming** gives the best speed/accuracy tradeoff for English. The small model only helps if you need LocalAgreement (for subtitle-grade timestamps) or non-English languages.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Findings
|
||||||
|
|
||||||
|
### Speed (RTF = processing time / audio duration, lower is better)
|
||||||
|
|
||||||
|
1. **mlx-whisper + LocalAgreement + base** is the fastest combo on Apple Silicon: 0.05-0.06x RTF on English. 30 seconds of audio in under 2 seconds.
|
||||||
|
2. For **faster-whisper**, SimulStreaming is faster than LocalAgreement. For **mlx-whisper**, it is the opposite: LocalAgreement (0.05-0.06x) outperforms SimulStreaming (0.11-0.14x) on speed.
|
||||||
|
3. **voxtral-mlx** runs at 0.18-0.32x RTF -- 3-5x slower than mlx-whisper base, but well within real-time.
|
||||||
|
4. **voxtral (HF transformers)** hits 1.0-1.3x RTF. At the real-time boundary on Apple Silicon. Use the MLX variant instead.
|
||||||
|
5. The **small** model is 2-3x slower than base across all backends.
|
||||||
|
|
||||||
|
### Accuracy (WER = Word Error Rate, lower is better)
|
||||||
|
|
||||||
|
1. **SimulStreaming** gives dramatically lower WER than LocalAgreement on the whisper backends. On the 30s English file: 5.3% vs 23-44%.
|
||||||
|
2. **voxtral-mlx** hits 0% on short English and 9.2% on multi-speaker. It auto-detects language natively. Whisper also supports `--language auto`, but tends to bias towards English on short segments.
|
||||||
|
3. **LocalAgreement** tends to repeat the last sentence at end-of-stream (a known LCP artifact), inflating WER. This is visible in the 21% WER on the 7s file -- the same 4 extra words appear in every LA run.
|
||||||
|
4. On **French** with the correct `--language fr`, whisper base achieves 25-40% WER -- comparable to Voxtral's 28-37%. The small model does not consistently improve French WER.
|
||||||
|
|
||||||
|
### Timestamps (MAE = Mean Absolute Error on word start times)
|
||||||
|
|
||||||
|
1. **LocalAgreement** gives the best timestamps on English (0.08-0.09s MAE).
|
||||||
|
2. **SimulStreaming** is less precise (0.22-0.40s MAE) but good enough for most applications.
|
||||||
|
3. On French with silence gaps, **base model timestamps are unreliable** (3.4-5s MAE). The **small model fixes this** (0.05-0.22s MAE). This is the strongest argument for using `small` over `base`.
|
||||||
|
4. **voxtral-mlx** has good timestamps on English (0.18-0.25s MAE) but drifts on audio with long silence gaps (3.4s MAE on the French file).
|
||||||
|
|
||||||
|
### VAC (Voice Activity Classification) Impact
|
||||||
|
|
||||||
|
| Backend | Policy | VAC | 7s English WER | 30s English WER |
|
||||||
|
|---------|--------|-----|----------------|-----------------|
|
||||||
|
| faster-whisper | LocalAgreement | on | 21.1% | 44.7% |
|
||||||
|
| faster-whisper | LocalAgreement | off | 100.0% | 100.0% |
|
||||||
|
| voxtral-mlx | voxtral | on | 0.0% | 9.2% |
|
||||||
|
| voxtral-mlx | voxtral | off | 0.0% | 9.2% |
|
||||||
|
|
||||||
|
- **Whisper backends need VAC** to work in streaming mode. Without it the buffer logic breaks down and you get empty or garbage output.
|
||||||
|
- **Voxtral is unaffected by VAC** since it handles its own internal chunking. Identical results with or without. VAC still saves compute on silent segments.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Recommendations
|
||||||
|
|
||||||
|
| Use Case | Backend | Policy | Model | Notes |
|
||||||
|
|----------|---------|--------|-------|-------|
|
||||||
|
| Fastest English (Apple Silicon) | mlx-whisper | SimulStreaming | base | 0.11x RTF, 5.3% WER |
|
||||||
|
| Fastest English (Linux/GPU) | faster-whisper | SimulStreaming | base | 0.10x RTF, 5.3% WER |
|
||||||
|
| Best accuracy, English | faster-whisper | SimulStreaming | small | 0.26x RTF, 5.3% WER, still fast |
|
||||||
|
| Multilingual / auto-detect | voxtral-mlx | voxtral | 4B | 100+ languages, 0.18-0.32x RTF |
|
||||||
|
| Best timestamps | any | LocalAgreement | small | 0.05-0.09s MAE, good for subtitles |
|
||||||
|
| Low memory / embedded | mlx-whisper | SimulStreaming | base | Smallest footprint, fastest response |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Caveats
|
||||||
|
|
||||||
|
- **3 test files, ~53 seconds total.** Results give relative rankings between backends but should not be taken as definitive WER numbers. Run on your own data for production decisions.
|
||||||
|
- **RTF varies between runs** (up to +/-30%) depending on thermal state, background processes, and model caching. The numbers above are single sequential runs on a warm machine.
|
||||||
|
- **Only base and small tested.** Medium and large-v3 would likely improve WER at the cost of higher RTF. We did not test them here because they are slow on Apple Silicon without GPU.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Reproducing These Benchmarks
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install test dependencies
|
||||||
|
pip install -e ".[test]"
|
||||||
|
|
||||||
|
# Single backend test
|
||||||
|
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --model base --no-realtime
|
||||||
|
|
||||||
|
# With a specific language
|
||||||
|
python test_backend_offline.py --backend mlx-whisper --policy simulstreaming --model small --lan fr --no-realtime
|
||||||
|
|
||||||
|
# Multi-backend auto-detect benchmark
|
||||||
|
python test_backend_offline.py --benchmark --no-realtime
|
||||||
|
|
||||||
|
# Export to JSON
|
||||||
|
python test_backend_offline.py --benchmark --no-realtime --json results.json
|
||||||
|
|
||||||
|
# Test with your own audio
|
||||||
|
python test_backend_offline.py --backend voxtral-mlx --audio your_file.wav --no-realtime
|
||||||
|
```
|
||||||
|
|
||||||
|
The benchmark harness computes WER and timestamp accuracy automatically when ground truth
|
||||||
|
`.transcript.json` files exist alongside the audio files. See `audio_tests/` for the format.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Help Us Benchmark on More Hardware
|
||||||
|
|
||||||
|
These results are from a single Apple M4 machine. We'd love to see numbers from other setups: Linux with CUDA GPUs, older Macs, different CPU architectures, cloud instances, etc.
|
||||||
|
|
||||||
|
If you run the benchmark on your hardware, please open an issue or PR with your results and we will add them here. The more data points we have, the better the recommendations get.
|
||||||
|
|
||||||
|
What we are especially interested in:
|
||||||
|
- **NVIDIA GPUs** (RTX 3090, 4090, A100, T4, etc.) with faster-whisper
|
||||||
|
- **Older Apple Silicon** (M1, M2, M3) with mlx-whisper and voxtral-mlx
|
||||||
|
- **Medium and large-v3 models** (we only tested base and small so far)
|
||||||
|
- **Longer audio files** or domain-specific audio (medical, legal, call center)
|
||||||
|
- **Other languages** beyond English and French
|
||||||
120
Dockerfile
120
Dockerfile
@@ -1,83 +1,75 @@
|
|||||||
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
|
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
|
||||||
|
|
||||||
|
# --- MARK: Builder Stage
|
||||||
|
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 AS builder-gpu
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
ARG EXTRAS
|
RUN apt-get update && \
|
||||||
ARG HF_PRECACHE_DIR
|
apt-get install -y --no-install-recommends \
|
||||||
ARG HF_TKN_FILE
|
build-essential \
|
||||||
|
python3-dev && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install UV and set up the environment
|
||||||
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
|
||||||
|
ENV UV_PYTHON_PREFERENCE=only-managed
|
||||||
|
ENV UV_PYTHON_INSTALL_DIR=/python
|
||||||
|
|
||||||
|
RUN uv python install 3.12
|
||||||
|
|
||||||
|
# Install dependencies first to leverage caching
|
||||||
|
ARG EXTRAS=cu129
|
||||||
|
COPY pyproject.toml uv.lock /app/
|
||||||
|
RUN set -eux; \
|
||||||
|
set --; \
|
||||||
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
|
set -- "$@" --extra "$extra"; \
|
||||||
|
done; \
|
||||||
|
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
|
||||||
|
|
||||||
|
# Copy the source code and install the package only
|
||||||
|
COPY whisperlivekit /app/whisperlivekit
|
||||||
|
RUN set -eux; \
|
||||||
|
set --; \
|
||||||
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
|
set -- "$@" --extra "$extra"; \
|
||||||
|
done; \
|
||||||
|
uv sync --frozen --no-editable --no-cache "$@"
|
||||||
|
|
||||||
|
# --- MARK: Runtime Stage
|
||||||
|
FROM nvidia/cuda:12.9.1-cudnn-runtime-ubuntu24.04
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --no-install-recommends \
|
apt-get install -y --no-install-recommends \
|
||||||
python3 \
|
ffmpeg &&\
|
||||||
python3-pip \
|
|
||||||
python3-venv \
|
|
||||||
ffmpeg \
|
|
||||||
git \
|
|
||||||
build-essential \
|
|
||||||
python3-dev \
|
|
||||||
ca-certificates && \
|
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN python3 -m venv /opt/venv
|
# Copy UV binaries
|
||||||
ENV PATH="/opt/venv/bin:$PATH"
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
# timeout/retries for large torch wheels
|
# Copy the Python version
|
||||||
RUN pip3 install --upgrade pip setuptools wheel && \
|
COPY --from=builder-gpu --chown=python:python /python /python
|
||||||
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
|
|
||||||
--index-url https://download.pytorch.org/whl/cu129 \
|
|
||||||
torch torchaudio \
|
|
||||||
|| (echo "Initial install failed — retrying with extended timeout..." && \
|
|
||||||
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
|
|
||||||
--index-url https://download.pytorch.org/whl/cu129 \
|
|
||||||
torch torchvision torchaudio)
|
|
||||||
|
|
||||||
COPY . .
|
# Copy the virtual environment with all dependencies installed
|
||||||
|
COPY --from=builder-gpu /app/.venv /app/.venv
|
||||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
|
||||||
RUN if [ -n "$EXTRAS" ]; then \
|
|
||||||
echo "Installing with extras: [$EXTRAS]"; \
|
|
||||||
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
|
|
||||||
else \
|
|
||||||
echo "Installing base package only"; \
|
|
||||||
pip install --no-cache-dir whisperlivekit; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# In-container caching for Hugging Face models by:
|
|
||||||
# A) Make the cache directory persistent via an anonymous volume.
|
|
||||||
# Note: This only persists for a single, named container. This is
|
|
||||||
# only for convenience at de/test stage.
|
|
||||||
# For prod, it is better to use a named volume via host mount/k8s.
|
|
||||||
VOLUME ["/root/.cache/huggingface/hub"]
|
|
||||||
|
|
||||||
|
|
||||||
# or
|
|
||||||
# B) Conditionally copy a local pre-cache from the build context to the
|
|
||||||
# container's cache via the HF_PRECACHE_DIR build-arg.
|
|
||||||
# WARNING: This will copy ALL files in the pre-cache location.
|
|
||||||
|
|
||||||
# Conditionally copy a cache directory if provided
|
|
||||||
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
|
|
||||||
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
|
|
||||||
mkdir -p /root/.cache/huggingface/hub && \
|
|
||||||
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
|
|
||||||
else \
|
|
||||||
echo "No local Hugging Face cache specified, skipping copy"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
|
|
||||||
RUN if [ -n "$HF_TKN_FILE" ]; then \
|
|
||||||
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
|
|
||||||
mkdir -p /root/.cache/huggingface && \
|
|
||||||
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
|
|
||||||
else \
|
|
||||||
echo "No Hugging Face token file specified, skipping token setup"; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
|
ENV UV_PYTHON_DOWNLOADS=0
|
||||||
|
|
||||||
|
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
||||||
|
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||||
|
|
||||||
CMD ["--model", "medium"]
|
CMD ["--model", "medium"]
|
||||||
|
|||||||
@@ -1,60 +1,75 @@
|
|||||||
FROM python:3.13-slim
|
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
|
||||||
|
|
||||||
|
# --- MARK: Builder Stage
|
||||||
|
FROM debian:bookworm-slim AS builder-cpu
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
ARG EXTRAS
|
|
||||||
ARG HF_PRECACHE_DIR
|
|
||||||
ARG HF_TKN_FILE
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --no-install-recommends \
|
apt-get install -y --no-install-recommends \
|
||||||
ffmpeg \
|
|
||||||
git \
|
|
||||||
build-essential \
|
build-essential \
|
||||||
python3-dev && \
|
python3-dev && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Install CPU-only PyTorch
|
# Install UV and set up the environment
|
||||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
COPY . .
|
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
|
||||||
|
ENV UV_PYTHON_PREFERENCE=only-managed
|
||||||
|
ENV UV_PYTHON_INSTALL_DIR=/python
|
||||||
|
|
||||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
RUN uv python install 3.12
|
||||||
RUN if [ -n "$EXTRAS" ]; then \
|
|
||||||
echo "Installing with extras: [$EXTRAS]"; \
|
|
||||||
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
|
|
||||||
else \
|
|
||||||
echo "Installing base package only"; \
|
|
||||||
pip install --no-cache-dir whisperlivekit; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Enable in-container caching for Hugging Face models
|
# Install dependencies first to leverage caching
|
||||||
VOLUME ["/root/.cache/huggingface/hub"]
|
ARG EXTRAS=cpu
|
||||||
|
COPY pyproject.toml uv.lock /app/
|
||||||
|
RUN set -eux; \
|
||||||
|
set --; \
|
||||||
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
|
set -- "$@" --extra "$extra"; \
|
||||||
|
done; \
|
||||||
|
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
|
||||||
|
|
||||||
# Conditionally copy a local pre-cache from the build context
|
# Copy the source code and install the package only
|
||||||
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
|
COPY whisperlivekit /app/whisperlivekit
|
||||||
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
|
RUN set -eux; \
|
||||||
mkdir -p /root/.cache/huggingface/hub && \
|
set --; \
|
||||||
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
|
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||||
else \
|
set -- "$@" --extra "$extra"; \
|
||||||
echo "No local Hugging Face cache specified, skipping copy"; \
|
done; \
|
||||||
fi
|
uv sync --frozen --no-editable --no-cache "$@"
|
||||||
|
|
||||||
# Conditionally copy a Hugging Face token if provided
|
# --- MARK: Runtime Stage
|
||||||
RUN if [ -n "$HF_TKN_FILE" ]; then \
|
FROM debian:bookworm-slim
|
||||||
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
|
|
||||||
mkdir -p /root/.cache/huggingface && \
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
|
|
||||||
else \
|
WORKDIR /app
|
||||||
echo "No Hugging Face token file specified, skipping token setup"; \
|
|
||||||
fi
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
ffmpeg &&\
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy UV binaries
|
||||||
|
COPY --from=uvbin /uv /uvx /bin/
|
||||||
|
|
||||||
|
# Copy the Python version
|
||||||
|
COPY --from=builder-cpu --chown=python:python /python /python
|
||||||
|
|
||||||
|
# Copy the virtual environment with all dependencies installed
|
||||||
|
COPY --from=builder-cpu /app/.venv /app/.venv
|
||||||
|
|
||||||
# Expose port for the transcription server
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
|
ENV UV_PYTHON_DOWNLOADS=0
|
||||||
|
|
||||||
|
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
||||||
|
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||||
|
|
||||||
# Default args - you might want to use a smaller model for CPU
|
# Default args - you might want to use a smaller model for CPU
|
||||||
|
|||||||
139
README.md
139
README.md
@@ -1,28 +1,32 @@
|
|||||||
<h1 align="center">WhisperLiveKit</h1>
|
<h1 align="center">WLK</h1>
|
||||||
|
<p align="center"><b>WhisperLiveKit: Ultra-low-latency, self-hosted speech-to-text with speaker identification</b></p>
|
||||||
|
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>
|
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
|
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
|
||||||
|
<a href="https://huggingface.co/qfuxa/whisper-base-french-lora">
|
||||||
|
<img alt="Hugging Face Weights" src="https://img.shields.io/badge/🤗-Hugging%20Face%20Weights-yellow" />
|
||||||
|
</a>
|
||||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
|
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|
||||||
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
|
### Powered by Leading Research:
|
||||||
|
|
||||||
#### Powered by Leading Research:
|
**See the interactive playground in [this repo](https://github.com/QuentinFuxa/streamlit-d3-network) to explore how AlignAtt works**
|
||||||
|
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408).
|
||||||
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
|
||||||
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
|
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
|
||||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||||
|
- [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) (2025) - 4B-parameter multilingual speech model by Mistral AI
|
||||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
||||||
|
|
||||||
|
|
||||||
@@ -69,18 +73,60 @@ Go to `chrome-extension` for instructions.
|
|||||||
|
|
||||||
#### Optional Dependencies
|
#### Optional Dependencies
|
||||||
|
|
||||||
| Optional | `pip install` |
|
| Feature | `uv sync` | `pip install -e` |
|
||||||
|-----------|-------------|
|
|-----------|-------------|-------------|
|
||||||
| **Windows/Linux optimizations** | `faster-whisper` |
|
| **Apple Silicon MLX Whisper backend** | `uv sync --extra mlx-whisper` | `pip install -e ".[mlx-whisper]"` |
|
||||||
| **Apple Silicon optimizations** | `mlx-whisper` |
|
| **Voxtral (MLX backend, Apple Silicon)** | `uv sync --extra voxtral-mlx` | `pip install -e ".[voxtral-mlx]"` |
|
||||||
| **Translation** | `nllw` |
|
| **CPU PyTorch stack** | `uv sync --extra cpu` | `pip install -e ".[cpu]"` |
|
||||||
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
| **CUDA 12.9 PyTorch stack** | `uv sync --extra cu129` | `pip install -e ".[cu129]"` |
|
||||||
| OpenAI API | `openai` |
|
| **Translation** | `uv sync --extra translation` | `pip install -e ".[translation]"` |
|
||||||
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
|
| **Sentence tokenizer** | `uv sync --extra sentence_tokenizer` | `pip install -e ".[sentence_tokenizer]"` |
|
||||||
|
| **Voxtral (HF backend)** | `uv sync --extra voxtral-hf` | `pip install -e ".[voxtral-hf]"` |
|
||||||
|
| **Speaker diarization (Sortformer / NeMo)** | `uv sync --extra diarization-sortformer` | `pip install -e ".[diarization-sortformer]"` |
|
||||||
|
| *[Not recommended]* Speaker diarization with Diart | `uv sync --extra diarization-diart` | `pip install -e ".[diarization-diart]"` |
|
||||||
|
|
||||||
|
Supported GPU profiles:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Profile A: Sortformer diarization
|
||||||
|
uv sync --extra cu129 --extra diarization-sortformer
|
||||||
|
|
||||||
|
# Profile B: Voxtral HF + translation
|
||||||
|
uv sync --extra cu129 --extra voxtral-hf --extra translation
|
||||||
|
```
|
||||||
|
|
||||||
|
`voxtral-hf` and `diarization-sortformer` are intentionally incompatible extras and must be installed in separate environments.
|
||||||
|
|
||||||
See **Parameters & Configuration** below on how to use them.
|
See **Parameters & Configuration** below on how to use them.
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
See **[BENCHMARK.md](BENCHMARK.md)** for the full benchmark with tables, model size comparison, and more.
|
||||||
|
We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Voxtral Backend
|
||||||
|
|
||||||
|
WhisperLiveKit supports [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602),
|
||||||
|
a 4B-parameter speech model from Mistral AI that natively handles 100+ languages with automatic
|
||||||
|
language detection. Whisper also supports auto-detection (`--language auto`), but Voxtral's per-chunk
|
||||||
|
detection is more reliable and does not bias towards English.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Apple Silicon (native MLX, recommended)
|
||||||
|
pip install -e ".[voxtral-mlx]"
|
||||||
|
wlk --backend voxtral-mlx
|
||||||
|
|
||||||
|
# Linux/GPU (HuggingFace transformers)
|
||||||
|
pip install transformers torch
|
||||||
|
wlk --backend voxtral
|
||||||
|
```
|
||||||
|
|
||||||
|
Voxtral uses its own streaming policy and does not use LocalAgreement or SimulStreaming.
|
||||||
|
See [BENCHMARK.md](BENCHMARK.md) for performance numbers.
|
||||||
|
|
||||||
### Usage Examples
|
### Usage Examples
|
||||||
|
|
||||||
@@ -92,6 +138,9 @@ wlk --model large-v3 --language fr --target-language da
|
|||||||
|
|
||||||
# Diarization and server listening on */80
|
# Diarization and server listening on */80
|
||||||
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||||
|
|
||||||
|
# Voxtral multilingual (auto-detects language)
|
||||||
|
wlk --backend voxtral-mlx
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -143,13 +192,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
|
||||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
|
||||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||||
| `--diarization` | Enable speaker identification | `False` |
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
| `--backend` | ASR backend selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. Options: `mlx-whisper`, `faster-whisper`, `whisper`, `openai-api` (LocalAgreement only), `voxtral-mlx` (Apple Silicon), `voxtral` (HuggingFace) | `auto` |
|
||||||
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
|
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
|
||||||
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
|
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
|
||||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||||
@@ -159,6 +208,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||||
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
|
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
|
||||||
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
||||||
|
| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). Only works with native Whisper backend (`--backend whisper`) | `None` |
|
||||||
|
|
||||||
| Translation options | Description | Default |
|
| Translation options | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
@@ -168,7 +218,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
| Diarization options | Description | Default |
|
| Diarization options | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||||
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
|
| `--disable-punctuation-split` | [NOT FUNCTIONAL IN 0.2.15 / 0.2.16] Disable punctuation based splits. See #214 | `False` |
|
||||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||||
|
|
||||||
@@ -186,7 +236,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
| `--never-fire` | Never truncate incomplete words | `False` |
|
| `--never-fire` | Never truncate incomplete words | `False` |
|
||||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
| `--max-context-tokens` | Maximum context tokens | Depends on model used, but usually 448. |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -245,7 +295,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk
|
|||||||
|
|
||||||
**CPU only:**
|
**CPU only:**
|
||||||
```bash
|
```bash
|
||||||
docker build -f Dockerfile.cpu -t wlk .
|
docker build -f Dockerfile.cpu -t wlk --build-arg EXTRAS="cpu" .
|
||||||
docker run -p 8000:8000 --name wlk wlk
|
docker run -p 8000:8000 --name wlk wlk
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -257,6 +307,18 @@ docker run -p 8000:8000 --name wlk wlk
|
|||||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Compose (recommended for cache + token wiring):**
|
||||||
|
```bash
|
||||||
|
# GPU Sortformer profile
|
||||||
|
docker compose up --build wlk-gpu-sortformer
|
||||||
|
|
||||||
|
# GPU Voxtral profile
|
||||||
|
docker compose up --build wlk-gpu-voxtral
|
||||||
|
|
||||||
|
# CPU service
|
||||||
|
docker compose up --build wlk-cpu
|
||||||
|
```
|
||||||
|
|
||||||
### Memory Requirements
|
### Memory Requirements
|
||||||
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
|
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
|
||||||
|
|
||||||
@@ -264,9 +326,34 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
|||||||
#### Customization
|
#### Customization
|
||||||
|
|
||||||
- `--build-arg` Options:
|
- `--build-arg` Options:
|
||||||
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
- `EXTRAS="cu129,diarization-sortformer"` - GPU Sortformer profile extras.
|
||||||
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
- `EXTRAS="cu129,voxtral-hf,translation"` - GPU Voxtral profile extras.
|
||||||
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
- `EXTRAS="cpu,diarization-diart,translation"` - CPU profile extras.
|
||||||
|
- Hugging Face cache + token are configured in `compose.yml` using a named volume and `HF_TKN_FILE` (default: `./token`).
|
||||||
|
|
||||||
## 🔮 Use Cases
|
## Testing & Benchmarks
|
||||||
|
|
||||||
|
WhisperLiveKit includes a unit test suite and an offline benchmark harness.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install test dependencies
|
||||||
|
pip install -e ".[test]"
|
||||||
|
|
||||||
|
# Run unit tests (no model download required)
|
||||||
|
pytest tests/ -v
|
||||||
|
|
||||||
|
# Benchmark a single backend
|
||||||
|
python test_backend_offline.py --backend faster-whisper --no-realtime
|
||||||
|
|
||||||
|
# Benchmark all installed backends
|
||||||
|
python test_backend_offline.py --benchmark --no-realtime
|
||||||
|
|
||||||
|
# Export benchmark results as JSON
|
||||||
|
python test_backend_offline.py --benchmark --no-realtime --json results.json
|
||||||
|
```
|
||||||
|
|
||||||
|
See [BENCHMARK.md](BENCHMARK.md) for a full comparison of backends, policies, WER, speed, and
|
||||||
|
timestamp accuracy on Apple Silicon.
|
||||||
|
|
||||||
|
## Use Cases
|
||||||
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
|
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
|
||||||
|
|||||||
BIN
architecture.png
BIN
architecture.png
Binary file not shown.
|
Before Width: | Height: | Size: 406 KiB After Width: | Height: | Size: 422 KiB |
97
audio_tests/00_00_07_english_1_speaker.transcript.json
Normal file
97
audio_tests/00_00_07_english_1_speaker.transcript.json
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"word": "This",
|
||||||
|
"start": 0.0,
|
||||||
|
"end": 0.24
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "is",
|
||||||
|
"start": 0.24,
|
||||||
|
"end": 0.56
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "a",
|
||||||
|
"start": 0.56,
|
||||||
|
"end": 0.76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "transcription",
|
||||||
|
"start": 0.76,
|
||||||
|
"end": 1.32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "test.",
|
||||||
|
"start": 1.32,
|
||||||
|
"end": 2.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "We",
|
||||||
|
"start": 2.4,
|
||||||
|
"end": 2.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "want",
|
||||||
|
"start": 2.5,
|
||||||
|
"end": 2.66
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "to",
|
||||||
|
"start": 2.66,
|
||||||
|
"end": 2.84
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "see",
|
||||||
|
"start": 2.84,
|
||||||
|
"end": 3.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "if",
|
||||||
|
"start": 3.1,
|
||||||
|
"end": 3.34
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "we",
|
||||||
|
"start": 3.34,
|
||||||
|
"end": 3.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "can",
|
||||||
|
"start": 3.5,
|
||||||
|
"end": 3.68
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "use",
|
||||||
|
"start": 3.68,
|
||||||
|
"end": 4.04
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "smaller",
|
||||||
|
"start": 4.04,
|
||||||
|
"end": 4.76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "chunks.",
|
||||||
|
"start": 4.76,
|
||||||
|
"end": 5.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "What",
|
||||||
|
"start": 6.06,
|
||||||
|
"end": 6.32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "do",
|
||||||
|
"start": 6.32,
|
||||||
|
"end": 6.44
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "you",
|
||||||
|
"start": 6.44,
|
||||||
|
"end": 6.58
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "think?",
|
||||||
|
"start": 6.58,
|
||||||
|
"end": 6.84
|
||||||
|
}
|
||||||
|
]
|
||||||
177
audio_tests/00_00_16_french_1_speaker.transcript.json
Normal file
177
audio_tests/00_00_16_french_1_speaker.transcript.json
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"word": "Ok,",
|
||||||
|
"start": 2.02,
|
||||||
|
"end": 2.38
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "là",
|
||||||
|
"start": 2.52,
|
||||||
|
"end": 2.58
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "c",
|
||||||
|
"start": 2.58,
|
||||||
|
"end": 2.74
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "'est",
|
||||||
|
"start": 2.74,
|
||||||
|
"end": 2.76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "un",
|
||||||
|
"start": 2.76,
|
||||||
|
"end": 2.86
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "test,",
|
||||||
|
"start": 2.86,
|
||||||
|
"end": 3.2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "on",
|
||||||
|
"start": 3.34,
|
||||||
|
"end": 3.34
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "veut",
|
||||||
|
"start": 3.34,
|
||||||
|
"end": 3.48
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "voir",
|
||||||
|
"start": 3.48,
|
||||||
|
"end": 3.86
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "si",
|
||||||
|
"start": 3.86,
|
||||||
|
"end": 4.14
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "ça",
|
||||||
|
"start": 4.14,
|
||||||
|
"end": 4.26
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "arrive",
|
||||||
|
"start": 4.26,
|
||||||
|
"end": 4.36
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "à",
|
||||||
|
"start": 4.36,
|
||||||
|
"end": 4.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "capté",
|
||||||
|
"start": 4.5,
|
||||||
|
"end": 4.78
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "le",
|
||||||
|
"start": 4.78,
|
||||||
|
"end": 4.9
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "silence.",
|
||||||
|
"start": 4.9,
|
||||||
|
"end": 5.44
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "Là",
|
||||||
|
"start": 9.24,
|
||||||
|
"end": 9.6
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "il",
|
||||||
|
"start": 9.6,
|
||||||
|
"end": 9.78
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "est",
|
||||||
|
"start": 9.78,
|
||||||
|
"end": 9.84
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "une",
|
||||||
|
"start": 9.84,
|
||||||
|
"end": 9.96
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "telle",
|
||||||
|
"start": 9.96,
|
||||||
|
"end": 10.12
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "seconde",
|
||||||
|
"start": 10.12,
|
||||||
|
"end": 10.38
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "de",
|
||||||
|
"start": 10.38,
|
||||||
|
"end": 10.48
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "silence",
|
||||||
|
"start": 10.48,
|
||||||
|
"end": 10.78
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "et",
|
||||||
|
"start": 10.78,
|
||||||
|
"end": 11.06
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "je",
|
||||||
|
"start": 11.06,
|
||||||
|
"end": 11.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "vous",
|
||||||
|
"start": 11.16,
|
||||||
|
"end": 11.32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "parle.",
|
||||||
|
"start": 11.32,
|
||||||
|
"end": 11.68
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "Et",
|
||||||
|
"start": 13.28,
|
||||||
|
"end": 13.64
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "voilà,",
|
||||||
|
"start": 13.64,
|
||||||
|
"end": 13.96
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "allez",
|
||||||
|
"start": 14.36,
|
||||||
|
"end": 14.62
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "on",
|
||||||
|
"start": 14.62,
|
||||||
|
"end": 14.78
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "va",
|
||||||
|
"start": 14.78,
|
||||||
|
"end": 14.88
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "tester",
|
||||||
|
"start": 14.88,
|
||||||
|
"end": 15.06
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "ça.",
|
||||||
|
"start": 15.06,
|
||||||
|
"end": 15.36
|
||||||
|
}
|
||||||
|
]
|
||||||
382
audio_tests/00_00_30_english_3_speakers.transcript.json
Normal file
382
audio_tests/00_00_30_english_3_speakers.transcript.json
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"word": "Transcription",
|
||||||
|
"start": 0.0,
|
||||||
|
"end": 0.6
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "technology",
|
||||||
|
"start": 0.6,
|
||||||
|
"end": 1.24
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "has",
|
||||||
|
"start": 1.24,
|
||||||
|
"end": 1.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "improved",
|
||||||
|
"start": 1.5,
|
||||||
|
"end": 1.96
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "so",
|
||||||
|
"start": 1.96,
|
||||||
|
"end": 2.32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "much",
|
||||||
|
"start": 2.32,
|
||||||
|
"end": 2.68
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "in",
|
||||||
|
"start": 2.68,
|
||||||
|
"end": 2.94
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "the",
|
||||||
|
"start": 2.94,
|
||||||
|
"end": 3.02
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "past",
|
||||||
|
"start": 3.02,
|
||||||
|
"end": 3.24
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "few",
|
||||||
|
"start": 3.24,
|
||||||
|
"end": 3.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "years.",
|
||||||
|
"start": 3.5,
|
||||||
|
"end": 3.96
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "Have",
|
||||||
|
"start": 4.56,
|
||||||
|
"end": 4.74
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "you",
|
||||||
|
"start": 4.74,
|
||||||
|
"end": 4.9
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "noticed",
|
||||||
|
"start": 4.9,
|
||||||
|
"end": 5.26
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "how",
|
||||||
|
"start": 5.26,
|
||||||
|
"end": 5.52
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "accurate",
|
||||||
|
"start": 5.52,
|
||||||
|
"end": 6.08
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "real",
|
||||||
|
"start": 6.08,
|
||||||
|
"end": 6.42
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "-time",
|
||||||
|
"start": 6.42,
|
||||||
|
"end": 6.74
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "speech",
|
||||||
|
"start": 6.74,
|
||||||
|
"end": 7.24
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "to",
|
||||||
|
"start": 7.24,
|
||||||
|
"end": 7.46
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "text",
|
||||||
|
"start": 7.46,
|
||||||
|
"end": 7.78
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "is",
|
||||||
|
"start": 7.78,
|
||||||
|
"end": 8.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "now?",
|
||||||
|
"start": 8.0,
|
||||||
|
"end": 8.3
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "Absolutely.",
|
||||||
|
"start": 8.7,
|
||||||
|
"end": 9.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "I",
|
||||||
|
"start": 10.04,
|
||||||
|
"end": 10.38
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "use",
|
||||||
|
"start": 10.38,
|
||||||
|
"end": 10.56
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "it",
|
||||||
|
"start": 10.56,
|
||||||
|
"end": 10.76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "all",
|
||||||
|
"start": 10.76,
|
||||||
|
"end": 10.9
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "the",
|
||||||
|
"start": 10.9,
|
||||||
|
"end": 11.04
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "time",
|
||||||
|
"start": 11.04,
|
||||||
|
"end": 11.32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "for",
|
||||||
|
"start": 11.32,
|
||||||
|
"end": 11.54
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "taking",
|
||||||
|
"start": 11.54,
|
||||||
|
"end": 11.86
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "notes",
|
||||||
|
"start": 11.86,
|
||||||
|
"end": 12.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "during",
|
||||||
|
"start": 12.16,
|
||||||
|
"end": 12.54
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "meetings.",
|
||||||
|
"start": 12.54,
|
||||||
|
"end": 12.94
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "It's",
|
||||||
|
"start": 13.6,
|
||||||
|
"end": 13.8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "amazing",
|
||||||
|
"start": 13.8,
|
||||||
|
"end": 14.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "how",
|
||||||
|
"start": 14.1,
|
||||||
|
"end": 14.48
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "it",
|
||||||
|
"start": 14.48,
|
||||||
|
"end": 14.62
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "can",
|
||||||
|
"start": 14.62,
|
||||||
|
"end": 14.74
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "recognise",
|
||||||
|
"start": 14.74,
|
||||||
|
"end": 15.24
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "different",
|
||||||
|
"start": 15.24,
|
||||||
|
"end": 15.68
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "speakers",
|
||||||
|
"start": 15.68,
|
||||||
|
"end": 16.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "and",
|
||||||
|
"start": 16.16,
|
||||||
|
"end": 16.8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "even",
|
||||||
|
"start": 16.8,
|
||||||
|
"end": 17.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "add",
|
||||||
|
"start": 17.1,
|
||||||
|
"end": 17.44
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "punctuation.",
|
||||||
|
"start": 17.44,
|
||||||
|
"end": 18.36
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "Yeah,",
|
||||||
|
"start": 18.88,
|
||||||
|
"end": 19.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "but",
|
||||||
|
"start": 19.36,
|
||||||
|
"end": 19.52
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "sometimes",
|
||||||
|
"start": 19.52,
|
||||||
|
"end": 20.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "noise",
|
||||||
|
"start": 20.16,
|
||||||
|
"end": 20.54
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "can",
|
||||||
|
"start": 20.54,
|
||||||
|
"end": 20.8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "still",
|
||||||
|
"start": 20.8,
|
||||||
|
"end": 21.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "cause",
|
||||||
|
"start": 21.1,
|
||||||
|
"end": 21.44
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "mistakes.",
|
||||||
|
"start": 21.44,
|
||||||
|
"end": 21.94
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "Does",
|
||||||
|
"start": 22.68,
|
||||||
|
"end": 22.9
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "this",
|
||||||
|
"start": 22.9,
|
||||||
|
"end": 23.12
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "system",
|
||||||
|
"start": 23.12,
|
||||||
|
"end": 23.46
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "handle",
|
||||||
|
"start": 23.46,
|
||||||
|
"end": 23.88
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "that",
|
||||||
|
"start": 23.88,
|
||||||
|
"end": 24.12
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "well?",
|
||||||
|
"start": 24.12,
|
||||||
|
"end": 24.42
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "It",
|
||||||
|
"start": 24.42,
|
||||||
|
"end": 25.32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "does",
|
||||||
|
"start": 25.32,
|
||||||
|
"end": 25.48
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "a",
|
||||||
|
"start": 25.48,
|
||||||
|
"end": 25.62
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "pretty",
|
||||||
|
"start": 25.62,
|
||||||
|
"end": 25.88
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "good",
|
||||||
|
"start": 25.88,
|
||||||
|
"end": 26.08
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "job",
|
||||||
|
"start": 26.08,
|
||||||
|
"end": 26.32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "filtering",
|
||||||
|
"start": 26.32,
|
||||||
|
"end": 26.8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "noise,",
|
||||||
|
"start": 26.8,
|
||||||
|
"end": 27.18
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "especially",
|
||||||
|
"start": 27.36,
|
||||||
|
"end": 28.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "with",
|
||||||
|
"start": 28.0,
|
||||||
|
"end": 28.28
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "models",
|
||||||
|
"start": 28.28,
|
||||||
|
"end": 28.62
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "that",
|
||||||
|
"start": 28.62,
|
||||||
|
"end": 28.94
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "use",
|
||||||
|
"start": 28.94,
|
||||||
|
"end": 29.22
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "voice",
|
||||||
|
"start": 29.22,
|
||||||
|
"end": 29.54
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"word": "active.",
|
||||||
|
"start": 29.54,
|
||||||
|
"end": 29.9
|
||||||
|
}
|
||||||
|
]
|
||||||
57
audio_tests/generate_transcripts.py
Normal file
57
audio_tests/generate_transcripts.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Generate word-level timestamped transcripts using faster-whisper (offline).
|
||||||
|
|
||||||
|
Produces one JSON file per audio with: [{word, start, end}, ...]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
AUDIO_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
FILES = [
|
||||||
|
("00_00_07_english_1_speaker.wav", "en"),
|
||||||
|
("00_00_16_french_1_speaker.wav", "fr"),
|
||||||
|
("00_00_30_english_3_speakers.wav", "en"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Loading faster-whisper model (base, cpu, float32)...")
|
||||||
|
model = WhisperModel("base", device="cpu", compute_type="float32")
|
||||||
|
|
||||||
|
for filename, lang in FILES:
|
||||||
|
audio_path = os.path.join(AUDIO_DIR, filename)
|
||||||
|
out_path = os.path.join(
|
||||||
|
AUDIO_DIR, filename.rsplit(".", 1)[0] + ".transcript.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"Transcribing: {filename} (language={lang})")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
segments, info = model.transcribe(
|
||||||
|
audio_path, word_timestamps=True, language=lang
|
||||||
|
)
|
||||||
|
|
||||||
|
words = []
|
||||||
|
for segment in segments:
|
||||||
|
if segment.words:
|
||||||
|
for w in segment.words:
|
||||||
|
words.append({
|
||||||
|
"word": w.word.strip(),
|
||||||
|
"start": round(w.start, 3),
|
||||||
|
"end": round(w.end, 3),
|
||||||
|
})
|
||||||
|
print(f" {w.start:6.2f} - {w.end:6.2f} {w.word.strip()}")
|
||||||
|
|
||||||
|
with open(out_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(words, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
print(f"\n -> {len(words)} words written to {os.path.basename(out_path)}")
|
||||||
|
|
||||||
|
print("\nDone.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
benchmark_chart.png
Normal file
BIN
benchmark_chart.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 69 KiB |
BIN
benchmark_scatter.png
Normal file
BIN
benchmark_scatter.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 95 KiB |
@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
|
|||||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
||||||
|
|
||||||
## Running this extension
|
## Running this extension
|
||||||
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
||||||
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
52
compose.yml
Normal file
52
compose.yml
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
services:
|
||||||
|
wlk-gpu-sortformer:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
args:
|
||||||
|
EXTRAS: ${GPU_SORTFORMER_EXTRAS:-cu129,diarization-sortformer}
|
||||||
|
image: wlk:gpu-sortformer
|
||||||
|
gpus: all
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- hf-cache:/root/.cache/huggingface/hub
|
||||||
|
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||||
|
environment:
|
||||||
|
- HF_TOKEN
|
||||||
|
command: ["--model", "medium", "--diarization", "--pcm-input"]
|
||||||
|
|
||||||
|
wlk-gpu-voxtral:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
args:
|
||||||
|
EXTRAS: ${GPU_VOXTRAL_EXTRAS:-cu129,voxtral-hf,translation}
|
||||||
|
image: wlk:gpu-voxtral
|
||||||
|
gpus: all
|
||||||
|
ports:
|
||||||
|
- "8001:8000"
|
||||||
|
volumes:
|
||||||
|
- hf-cache:/root/.cache/huggingface/hub
|
||||||
|
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||||
|
environment:
|
||||||
|
- HF_TOKEN
|
||||||
|
command: ["--backend", "voxtral", "--pcm-input"]
|
||||||
|
|
||||||
|
wlk-cpu:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.cpu
|
||||||
|
args:
|
||||||
|
EXTRAS: ${CPU_EXTRAS:-cpu,diarization-diart,translation}
|
||||||
|
image: wlk:cpu
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- hf-cache:/root/.cache/huggingface/hub
|
||||||
|
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||||
|
environment:
|
||||||
|
- HF_TOKEN
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
hf-cache:
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
# Available Whisper model sizes:
|
|
||||||
|
|
||||||
- tiny.en (english only)
|
|
||||||
- tiny
|
|
||||||
- base.en (english only)
|
|
||||||
- base
|
|
||||||
- small.en (english only)
|
|
||||||
- small
|
|
||||||
- medium.en (english only)
|
|
||||||
- medium
|
|
||||||
- large-v1
|
|
||||||
- large-v2
|
|
||||||
- large-v3
|
|
||||||
- large-v3-turbo
|
|
||||||
|
|
||||||
## How to choose?
|
|
||||||
|
|
||||||
### Language Support
|
|
||||||
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
|
|
||||||
- **Multilingual**: Do not use `.en` models.
|
|
||||||
|
|
||||||
### Resource Constraints
|
|
||||||
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
|
|
||||||
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
|
|
||||||
- `base`: Good balance of speed and accuracy for basic use cases
|
|
||||||
- `small`: Better accuracy while still being resource-efficient
|
|
||||||
- **Good resources available**: Use `large` models for best accuracy
|
|
||||||
- `large-v2`: Excellent accuracy, good multilingual support
|
|
||||||
- `large-v3`: Best overall accuracy and language support
|
|
||||||
|
|
||||||
### Special Cases
|
|
||||||
- **No translation needed**: Use `large-v3-turbo`
|
|
||||||
- Same transcription quality as `large-v2` but significantly faster
|
|
||||||
- **Important**: Does not translate correctly, only transcribes
|
|
||||||
|
|
||||||
### Model Comparison Table
|
|
||||||
|
|
||||||
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|
|
||||||
|-------|--------|----------|--------------|-------------|---------------|
|
|
||||||
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
|
|
||||||
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
|
|
||||||
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
|
|
||||||
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
|
|
||||||
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
|
|
||||||
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
|
|
||||||
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
|
|
||||||
|
|
||||||
### Additional Considerations
|
|
||||||
|
|
||||||
**Model Performance**:
|
|
||||||
- Accuracy improves significantly from tiny to large models
|
|
||||||
- English-only models are ~10-15% more accurate for English audio
|
|
||||||
- Newer versions (v2, v3) have better punctuation and formatting
|
|
||||||
|
|
||||||
**Hardware Requirements**:
|
|
||||||
- `tiny`: ~1GB VRAM
|
|
||||||
- `base`: ~1GB VRAM
|
|
||||||
- `small`: ~2GB VRAM
|
|
||||||
- `medium`: ~5GB VRAM
|
|
||||||
- `large`: ~10GB VRAM
|
|
||||||
- `large‑v3‑turbo`: ~6GB VRAM
|
|
||||||
|
|
||||||
**Audio Quality Impact**:
|
|
||||||
- Clean, clear audio: smaller models may suffice
|
|
||||||
- Noisy, accented, or technical audio: larger models recommended
|
|
||||||
- Phone/low-quality audio: use at least `small` model
|
|
||||||
|
|
||||||
### Quick Decision Tree
|
|
||||||
1. English only? → Add `.en` to your choice
|
|
||||||
2. Limited resources or need speed? → `small` or smaller
|
|
||||||
3. Good hardware and want best quality? → `large-v3`
|
|
||||||
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
|
|
||||||
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
|
||||||
|
|
||||||
|
|
||||||
_______________________
|
|
||||||
|
|
||||||
# Translation Models and Backend
|
|
||||||
|
|
||||||
**Language Support**: ~200 languages
|
|
||||||
|
|
||||||
## Distilled Model Sizes Available
|
|
||||||
|
|
||||||
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|
|
||||||
|-------|------|------------|-------------|-------------|---------|
|
|
||||||
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
|
|
||||||
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
|
|
||||||
|
|
||||||
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
|
|
||||||
|
|
||||||
## Backend Performance
|
|
||||||
|
|
||||||
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|
|
||||||
|---------|---------------|--------------|--------------|
|
|
||||||
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
|
|
||||||
| Transformers | Baseline | High | None |
|
|
||||||
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
|
|
||||||
|
|
||||||
**Metrics**:
|
|
||||||
- CTranslate2: 50-100+ tokens/sec
|
|
||||||
- Transformers: 10-30 tokens/sec
|
|
||||||
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
|
|
||||||
|
|
||||||
## Quick Decision Matrix
|
|
||||||
|
|
||||||
**Choose 600M**: Limited resources, close to 0 lag
|
|
||||||
**Choose 1.3B**: Quality matters
|
|
||||||
**Choose Transformers**: On Apple Silicon
|
|
||||||
|
|
||||||
106
docs/default_and_custom_models.md
Normal file
106
docs/default_and_custom_models.md
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
# Models and Model Paths
|
||||||
|
|
||||||
|
## Defaults
|
||||||
|
|
||||||
|
**Default Whisper Model**: `base`
|
||||||
|
When no model is specified, WhisperLiveKit uses the `base` model, which provides a good balance of speed and accuracy for most use cases.
|
||||||
|
|
||||||
|
**Default Model Cache Directory**: `~/.cache/whisper`
|
||||||
|
Models are automatically downloaded from OpenAI's model hub and cached in this directory. You can override this with `--model_cache_dir`.
|
||||||
|
|
||||||
|
**Default Translation Model**: `600M` (NLLB-200-distilled)
|
||||||
|
When translation is enabled, the 600M distilled NLLB model is used by default. This provides good quality with minimal resource usage.
|
||||||
|
|
||||||
|
**Default Translation Backend**: `transformers`
|
||||||
|
The translation backend defaults to Transformers. On Apple Silicon, this automatically uses MPS acceleration for better performance.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
|
## Available Whisper model sizes:
|
||||||
|
|
||||||
|
| Available Model | Speed | Accuracy | Multilingual | Translation | Hardware Requirements | Best Use Case |
|
||||||
|
|--------------------|----------|-----------|--------------|-------------|----------------------|----------------------------------|
|
||||||
|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | ~1GB VRAM | Real-time, low resources |
|
||||||
|
| base(.en) | Fast | Good | Yes/No | Yes/No | ~1GB VRAM | Balanced performance |
|
||||||
|
| small(.en) | Medium | Better | Yes/No | Yes/No | ~2GB VRAM | Quality on limited hardware |
|
||||||
|
| medium(.en) | Slow | High | Yes/No | Yes/No | ~5GB VRAM | High quality, moderate resources |
|
||||||
|
| large-v2 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Good overall accuracy & language support |
|
||||||
|
| large-v3 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Best overall accuracy & language support |
|
||||||
|
| large-v3-turbo | Fast | Excellent | Yes | No | ~6GB VRAM | Fast, high-quality transcription |
|
||||||
|
|
||||||
|
|
||||||
|
### How to choose?
|
||||||
|
|
||||||
|
#### Language Support
|
||||||
|
- **English only**: Use `.en` (ex: `base.en`) models for better accuracy and faster processing when you only need English transcription
|
||||||
|
- **Multilingual**: Do not use `.en` models.
|
||||||
|
|
||||||
|
#### Special Cases
|
||||||
|
- **No translation needed**: Use `large-v3-turbo`
|
||||||
|
- Same transcription quality as `large-v2` but significantly faster
|
||||||
|
- **Important**: Does not translate correctly, only transcribes
|
||||||
|
|
||||||
|
### Additional Considerations
|
||||||
|
|
||||||
|
**Model Performance**:
|
||||||
|
- Accuracy improves significantly from tiny to large models
|
||||||
|
- English-only models are ~10-15% more accurate for English audio
|
||||||
|
- Newer versions (v2, v3) have better punctuation and formatting
|
||||||
|
|
||||||
|
**Audio Quality Impact**:
|
||||||
|
- Clean, clear audio: smaller models may suffice
|
||||||
|
- Noisy, accented, or technical audio: larger models recommended
|
||||||
|
- Phone/low-quality audio: use at least `small` model
|
||||||
|
|
||||||
|
_______________________
|
||||||
|
|
||||||
|
|
||||||
|
# Custom Models:
|
||||||
|
|
||||||
|
The `--model-path` parameter accepts:
|
||||||
|
|
||||||
|
## File Path
|
||||||
|
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
|
||||||
|
|
||||||
|
## Directory Path (recommended)
|
||||||
|
Must contain:
|
||||||
|
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
|
||||||
|
|
||||||
|
May optionally contain:
|
||||||
|
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||||
|
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||||
|
|
||||||
|
## Hugging Face Repo ID
|
||||||
|
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
||||||
|
|
||||||
|
To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignment heads are set to be all the heads of the last half layer of decoder.
|
||||||
|
|
||||||
|
|
||||||
|
_______________________
|
||||||
|
|
||||||
|
# Translation Models and Backend
|
||||||
|
|
||||||
|
**Language Support**: ~200 languages
|
||||||
|
|
||||||
|
## Distilled Model Sizes Available
|
||||||
|
|
||||||
|
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|
||||||
|
|-------|------|------------|-------------|-------------|---------|
|
||||||
|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
|
||||||
|
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
|
||||||
|
|
||||||
|
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
|
||||||
|
|
||||||
|
## Backend Performance
|
||||||
|
|
||||||
|
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|
||||||
|
|---------|---------------|--------------|--------------|
|
||||||
|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
|
||||||
|
| Transformers | Baseline | High | None |
|
||||||
|
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
|
||||||
|
|
||||||
|
**Metrics**:
|
||||||
|
- CTranslate2: 50-100+ tokens/sec
|
||||||
|
- Transformers: 10-30 tokens/sec
|
||||||
|
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# Model Path Formats
|
|
||||||
|
|
||||||
The `--model-path` parameter accepts:
|
|
||||||
|
|
||||||
## File Path
|
|
||||||
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
|
|
||||||
|
|
||||||
## Directory Path (recommended)
|
|
||||||
Must contain:
|
|
||||||
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
|
|
||||||
|
|
||||||
May optionally contain:
|
|
||||||
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
|
||||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
|
||||||
|
|
||||||
## Hugging Face Repo ID
|
|
||||||
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
|
||||||
|
|
||||||
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.
|
|
||||||
@@ -1,6 +1,114 @@
|
|||||||
# Supported Languages
|
# Transcription: Supported Language
|
||||||
|
|
||||||
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
|
WLK supports transcription in the following languages:
|
||||||
|
|
||||||
|
| ISO Code | Language Name |
|
||||||
|
|----------|---------------------|
|
||||||
|
| en | English |
|
||||||
|
| zh | Chinese |
|
||||||
|
| de | German |
|
||||||
|
| es | Spanish |
|
||||||
|
| ru | Russian |
|
||||||
|
| ko | Korean |
|
||||||
|
| fr | French |
|
||||||
|
| ja | Japanese |
|
||||||
|
| pt | Portuguese |
|
||||||
|
| tr | Turkish |
|
||||||
|
| pl | Polish |
|
||||||
|
| ca | Catalan |
|
||||||
|
| nl | Dutch |
|
||||||
|
| ar | Arabic |
|
||||||
|
| sv | Swedish |
|
||||||
|
| it | Italian |
|
||||||
|
| id | Indonesian |
|
||||||
|
| hi | Hindi |
|
||||||
|
| fi | Finnish |
|
||||||
|
| vi | Vietnamese |
|
||||||
|
| he | Hebrew |
|
||||||
|
| uk | Ukrainian |
|
||||||
|
| el | Greek |
|
||||||
|
| ms | Malay |
|
||||||
|
| cs | Czech |
|
||||||
|
| ro | Romanian |
|
||||||
|
| da | Danish |
|
||||||
|
| hu | Hungarian |
|
||||||
|
| ta | Tamil |
|
||||||
|
| no | Norwegian |
|
||||||
|
| th | Thai |
|
||||||
|
| ur | Urdu |
|
||||||
|
| hr | Croatian |
|
||||||
|
| bg | Bulgarian |
|
||||||
|
| lt | Lithuanian |
|
||||||
|
| la | Latin |
|
||||||
|
| mi | Maori |
|
||||||
|
| ml | Malayalam |
|
||||||
|
| cy | Welsh |
|
||||||
|
| sk | Slovak |
|
||||||
|
| te | Telugu |
|
||||||
|
| fa | Persian |
|
||||||
|
| lv | Latvian |
|
||||||
|
| bn | Bengali |
|
||||||
|
| sr | Serbian |
|
||||||
|
| az | Azerbaijani |
|
||||||
|
| sl | Slovenian |
|
||||||
|
| kn | Kannada |
|
||||||
|
| et | Estonian |
|
||||||
|
| mk | Macedonian |
|
||||||
|
| br | Breton |
|
||||||
|
| eu | Basque |
|
||||||
|
| is | Icelandic |
|
||||||
|
| hy | Armenian |
|
||||||
|
| ne | Nepali |
|
||||||
|
| mn | Mongolian |
|
||||||
|
| bs | Bosnian |
|
||||||
|
| kk | Kazakh |
|
||||||
|
| sq | Albanian |
|
||||||
|
| sw | Swahili |
|
||||||
|
| gl | Galician |
|
||||||
|
| mr | Marathi |
|
||||||
|
| pa | Punjabi |
|
||||||
|
| si | Sinhala |
|
||||||
|
| km | Khmer |
|
||||||
|
| sn | Shona |
|
||||||
|
| yo | Yoruba |
|
||||||
|
| so | Somali |
|
||||||
|
| af | Afrikaans |
|
||||||
|
| oc | Occitan |
|
||||||
|
| ka | Georgian |
|
||||||
|
| be | Belarusian |
|
||||||
|
| tg | Tajik |
|
||||||
|
| sd | Sindhi |
|
||||||
|
| gu | Gujarati |
|
||||||
|
| am | Amharic |
|
||||||
|
| yi | Yiddish |
|
||||||
|
| lo | Lao |
|
||||||
|
| uz | Uzbek |
|
||||||
|
| fo | Faroese |
|
||||||
|
| ht | Haitian Creole |
|
||||||
|
| ps | Pashto |
|
||||||
|
| tk | Turkmen |
|
||||||
|
| nn | Nynorsk |
|
||||||
|
| mt | Maltese |
|
||||||
|
| sa | Sanskrit |
|
||||||
|
| lb | Luxembourgish |
|
||||||
|
| my | Myanmar |
|
||||||
|
| bo | Tibetan |
|
||||||
|
| tl | Tagalog |
|
||||||
|
| mg | Malagasy |
|
||||||
|
| as | Assamese |
|
||||||
|
| tt | Tatar |
|
||||||
|
| haw | Hawaiian |
|
||||||
|
| ln | Lingala |
|
||||||
|
| ha | Hausa |
|
||||||
|
| ba | Bashkir |
|
||||||
|
| jw | Javanese |
|
||||||
|
| su | Sundanese |
|
||||||
|
| yue | Cantonese |
|
||||||
|
|
||||||
|
|
||||||
|
# Translation: Supported Languages
|
||||||
|
|
||||||
|
WLK supports translation into **201 languages** from the FLORES-200 dataset through the [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) translation system.
|
||||||
|
|
||||||
## How to Specify Languages
|
## How to Specify Languages
|
||||||
|
|
||||||
|
|||||||
@@ -40,4 +40,4 @@ This document introduce how to reuse the core components when you do **not** wan
|
|||||||
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
|
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
|
||||||
|
|
||||||
|
|
||||||
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.
|
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently. Just ensure `ffmpeg` is available.
|
||||||
@@ -82,16 +82,43 @@ print(torch.cuda.is_available(), torch.cuda.get_device_name())
|
|||||||
```python
|
```python
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
print("CUDA devices:", ctranslate2.get_cuda_device_count())
|
print("CUDA devices:", ctranslate2.get_cuda_device_count())
|
||||||
|
print("CUDA compute types:", ctranslate2.get_supported_compute_types("cuda", 0))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Note for aarch64 systems (e.g., NVIDIA DGX Spark):** Pre-built CUDA wheels may not be available for all CUDA versions on ARM architectures. If the wheel installation fails, you may need to compile CTranslate2 from source with CUDA support enabled.
|
||||||
|
|
||||||
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
|
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Hopper / Blackwell (`sm_121a`) systems
|
## Hopper / Blackwell (`sm_121a`) systems
|
||||||
> Reported in issue #276 (NVIDIA DGX Spark)
|
> Reported in issues #276 and #284 (NVIDIA DGX Spark)
|
||||||
|
|
||||||
CUDA 12.1a GPUs ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual hints:
|
CUDA 12.1a GPUs (e.g., NVIDIA GB10 on DGX Spark) ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual configuration.
|
||||||
|
|
||||||
|
### Error: `ptxas fatal : Value 'sm_121a' is not defined for option 'gpu-name'`
|
||||||
|
|
||||||
|
If you encounter this error after compiling CTranslate2 from source on aarch64 systems, Triton's bundled `ptxas` may not support the `sm_121a` architecture. The solution is to replace Triton's `ptxas` with the system's CUDA `ptxas`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Find your Python environment's Triton directory
|
||||||
|
python -c "import triton; import os; print(os.path.dirname(triton.__file__))"
|
||||||
|
|
||||||
|
# Copy the system ptxas to Triton's backend directory
|
||||||
|
# Replace <triton_path> with the output above
|
||||||
|
cp /usr/local/cuda/bin/ptxas <triton_path>/backends/nvidia/bin/ptxas
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, in a virtual environment:
|
||||||
|
```bash
|
||||||
|
cp /usr/local/cuda/bin/ptxas ~/wlk/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** On DGX Spark systems, CUDA is typically already in `PATH` (`/usr/local/cuda/bin`), so explicit `CUDA_HOME` and `PATH` exports may not be necessary. Verify with `which ptxas` before copying.
|
||||||
|
|
||||||
|
### Alternative: Environment variable approach
|
||||||
|
|
||||||
|
If the above doesn't work, you can try setting environment variables (though this may not resolve the `sm_121a` issue on all systems):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export CUDA_HOME="/usr/local/cuda-13.0"
|
export CUDA_HOME="/usr/local/cuda-13.0"
|
||||||
@@ -105,7 +132,7 @@ export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
|
|||||||
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
|
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
|
||||||
```
|
```
|
||||||
|
|
||||||
After exporting those variables (or adding them to your systemd service / shell profile), restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
|
After applying the fix, restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -4,27 +4,21 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "whisperlivekit"
|
name = "whisperlivekit"
|
||||||
version = "0.2.15"
|
version = "0.2.19"
|
||||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [{ name = "Quentin Fuxa" }]
|
||||||
{ name = "Quentin Fuxa" }
|
|
||||||
]
|
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.11, <3.14"
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: 3.13",
|
"Programming Language :: Python :: 3.13",
|
||||||
"Programming Language :: Python :: 3.14",
|
|
||||||
"Programming Language :: Python :: 3.15",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Topic :: Multimedia :: Sound/Audio :: Speech"
|
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi",
|
"fastapi",
|
||||||
@@ -32,17 +26,91 @@ dependencies = [
|
|||||||
"soundfile",
|
"soundfile",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"websockets",
|
"websockets",
|
||||||
"torchaudio>=2.0.0",
|
|
||||||
"torch>=2.0.0",
|
|
||||||
"huggingface-hub>=0.25.0",
|
"huggingface-hub>=0.25.0",
|
||||||
|
"faster-whisper>=1.2.0",
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"torchaudio>=2.0.0",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
|
||||||
translation = ["nllw"]
|
translation = ["nllw"]
|
||||||
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
||||||
|
mlx-whisper = [
|
||||||
|
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
]
|
||||||
|
voxtral-mlx = [
|
||||||
|
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||||
|
"mistral-common[audio]",
|
||||||
|
]
|
||||||
|
voxtral-hf = [
|
||||||
|
"transformers>=5.2.0; python_version >= '3.10'",
|
||||||
|
"mistral-common[audio]",
|
||||||
|
"accelerate>=0.12",
|
||||||
|
]
|
||||||
|
cpu = ["torch>=2.0.0", "torchaudio>=2.0.0"]
|
||||||
|
cu129 = [
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"torchaudio>=2.0.0",
|
||||||
|
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")',
|
||||||
|
]
|
||||||
|
diarization-sortformer = [
|
||||||
|
"nemo-toolkit[asr]>2.4; python_version >= '3.10' and python_version < '3.13'",
|
||||||
|
]
|
||||||
|
diarization-diart = [
|
||||||
|
"diart",
|
||||||
|
"torch<2.9.0",
|
||||||
|
"torchaudio<2.9.0",
|
||||||
|
"torchvision<0.24.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = ["rich>=14.3.3"]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
conflicts = [
|
||||||
|
[
|
||||||
|
{ extra = "cpu" },
|
||||||
|
{ extra = "cu129" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ extra = "diarization-diart" },
|
||||||
|
{ extra = "cu129" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ extra = "voxtral-hf" },
|
||||||
|
{ extra = "diarization-sortformer" },
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = [
|
||||||
|
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
|
||||||
|
]
|
||||||
|
torchaudio = [
|
||||||
|
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||||
|
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
|
||||||
|
]
|
||||||
|
torchvision = [
|
||||||
|
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cpu"
|
||||||
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cu129"
|
||||||
|
url = "https://download.pytorch.org/whl/cu129"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||||
@@ -56,15 +124,18 @@ packages = [
|
|||||||
"whisperlivekit",
|
"whisperlivekit",
|
||||||
"whisperlivekit.diarization",
|
"whisperlivekit.diarization",
|
||||||
"whisperlivekit.simul_whisper",
|
"whisperlivekit.simul_whisper",
|
||||||
|
"whisperlivekit.simul_whisper.mlx",
|
||||||
"whisperlivekit.whisper",
|
"whisperlivekit.whisper",
|
||||||
"whisperlivekit.whisper.assets",
|
"whisperlivekit.whisper.assets",
|
||||||
"whisperlivekit.whisper.normalizers",
|
"whisperlivekit.whisper.normalizers",
|
||||||
"whisperlivekit.web",
|
"whisperlivekit.web",
|
||||||
"whisperlivekit.local_agreement",
|
"whisperlivekit.local_agreement",
|
||||||
"whisperlivekit.silero_vad_models"
|
"whisperlivekit.voxtral_mlx",
|
||||||
|
"whisperlivekit.silero_vad_models",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||||
|
"whisperlivekit.whisper.normalizers" = ["*.json"]
|
||||||
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
||||||
|
|||||||
291
run_benchmark.py
Normal file
291
run_benchmark.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Comprehensive benchmark runner for WhisperLiveKit.
|
||||||
|
|
||||||
|
Tests all available backend+policy combinations across multiple audio files,
|
||||||
|
model sizes, and VAC on/off configurations. Outputs structured JSON that
|
||||||
|
is consumed by the report generator.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python run_benchmark.py # full benchmark
|
||||||
|
python run_benchmark.py --quick # subset (tiny models, fewer combos)
|
||||||
|
python run_benchmark.py --json results.json # custom output path
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import asdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.WARNING, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
logger = logging.getLogger("benchmark")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# Re-use harness functions
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
from test_backend_offline import (
|
||||||
|
AUDIO_TESTS_DIR,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
TestResult,
|
||||||
|
create_engine,
|
||||||
|
discover_audio_files,
|
||||||
|
download_sample_audio,
|
||||||
|
load_audio,
|
||||||
|
run_test,
|
||||||
|
)
|
||||||
|
|
||||||
|
CACHE_DIR = Path(__file__).parent / ".test_cache"
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_info() -> dict:
|
||||||
|
"""Collect system metadata for the report."""
|
||||||
|
info = {
|
||||||
|
"platform": platform.platform(),
|
||||||
|
"machine": platform.machine(),
|
||||||
|
"processor": platform.processor(),
|
||||||
|
"python_version": platform.python_version(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# macOS: get chip info
|
||||||
|
try:
|
||||||
|
chip = subprocess.check_output(
|
||||||
|
["sysctl", "-n", "machdep.cpu.brand_string"], text=True
|
||||||
|
).strip()
|
||||||
|
info["cpu"] = chip
|
||||||
|
except Exception:
|
||||||
|
info["cpu"] = platform.processor()
|
||||||
|
|
||||||
|
# RAM
|
||||||
|
try:
|
||||||
|
mem_bytes = int(
|
||||||
|
subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip()
|
||||||
|
)
|
||||||
|
info["ram_gb"] = round(mem_bytes / (1024**3))
|
||||||
|
except Exception:
|
||||||
|
info["ram_gb"] = None
|
||||||
|
|
||||||
|
# Backend versions
|
||||||
|
versions = {}
|
||||||
|
try:
|
||||||
|
import faster_whisper
|
||||||
|
versions["faster-whisper"] = faster_whisper.__version__
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
import mlx_whisper # noqa: F401
|
||||||
|
versions["mlx-whisper"] = "installed"
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
versions["mlx"] = mx.__version__
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
import transformers
|
||||||
|
versions["transformers"] = transformers.__version__
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
versions["torch"] = torch.__version__
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
info["backend_versions"] = versions
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def detect_combos(quick: bool = False) -> list:
|
||||||
|
"""Build list of (backend, policy, model_size) combos to test."""
|
||||||
|
combos = []
|
||||||
|
|
||||||
|
# Model sizes to test
|
||||||
|
model_sizes = ["tiny", "base", "small"] if not quick else ["tiny", "base"]
|
||||||
|
|
||||||
|
# faster-whisper
|
||||||
|
try:
|
||||||
|
import faster_whisper # noqa: F401
|
||||||
|
for model in model_sizes:
|
||||||
|
combos.append({"backend": "faster-whisper", "policy": "localagreement", "model": model})
|
||||||
|
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "model": model})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# mlx-whisper
|
||||||
|
try:
|
||||||
|
import mlx_whisper # noqa: F401
|
||||||
|
for model in model_sizes:
|
||||||
|
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "model": model})
|
||||||
|
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "model": model})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# voxtral-mlx (single model, single policy)
|
||||||
|
try:
|
||||||
|
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
|
||||||
|
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "model": ""})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# voxtral HF (single model, single policy)
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
|
||||||
|
combos.append({"backend": "voxtral", "policy": "voxtral", "model": ""})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return combos
|
||||||
|
|
||||||
|
|
||||||
|
def collect_audio_files() -> list:
|
||||||
|
"""Collect all benchmark audio files."""
|
||||||
|
files = []
|
||||||
|
|
||||||
|
# audio_tests/ directory
|
||||||
|
if AUDIO_TESTS_DIR.is_dir():
|
||||||
|
files.extend(discover_audio_files(str(AUDIO_TESTS_DIR)))
|
||||||
|
|
||||||
|
# JFK sample
|
||||||
|
jfk = CACHE_DIR / "jfk.wav"
|
||||||
|
if not jfk.exists():
|
||||||
|
jfk = download_sample_audio()
|
||||||
|
if jfk.exists():
|
||||||
|
files.append(jfk)
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
async def run_single_combo(
|
||||||
|
combo: dict, audio_files: list, vac: bool, lan: str, max_duration: float,
|
||||||
|
) -> list:
|
||||||
|
"""Run one backend+policy+model combo across all audio files."""
|
||||||
|
backend = combo["backend"]
|
||||||
|
policy = combo["policy"]
|
||||||
|
model = combo["model"]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
try:
|
||||||
|
engine = create_engine(
|
||||||
|
backend=backend,
|
||||||
|
model_size=model,
|
||||||
|
lan=lan,
|
||||||
|
vac=vac,
|
||||||
|
policy=policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quiet noisy loggers
|
||||||
|
for mod in (
|
||||||
|
"whisperlivekit.audio_processor",
|
||||||
|
"whisperlivekit.simul_whisper",
|
||||||
|
"whisperlivekit.tokens_alignment",
|
||||||
|
"whisperlivekit.simul_whisper.align_att_base",
|
||||||
|
"whisperlivekit.simul_whisper.simul_whisper",
|
||||||
|
):
|
||||||
|
logging.getLogger(mod).setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
for audio_path in audio_files:
|
||||||
|
duration = len(load_audio(str(audio_path))) / SAMPLE_RATE
|
||||||
|
if duration > max_duration:
|
||||||
|
logger.info(f" Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_lan = lan
|
||||||
|
if "french" in audio_path.name.lower() and lan == "en":
|
||||||
|
file_lan = "fr"
|
||||||
|
|
||||||
|
audio = load_audio(str(audio_path))
|
||||||
|
result = await run_test(
|
||||||
|
engine, audio, chunk_ms=100, realtime=False,
|
||||||
|
audio_file=audio_path.name, backend=backend,
|
||||||
|
policy=policy, lan=file_lan,
|
||||||
|
)
|
||||||
|
# Tag with extra metadata
|
||||||
|
result_dict = asdict(result)
|
||||||
|
result_dict["model_size"] = model
|
||||||
|
result_dict["vac"] = vac
|
||||||
|
results.append(result_dict)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f" FAILED: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def run_full_benchmark(combos, audio_files, max_duration=60.0):
|
||||||
|
"""Run all combos with VAC on and off."""
|
||||||
|
all_results = []
|
||||||
|
total = len(combos) * 2 # x2 for VAC on/off
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
for combo in combos:
|
||||||
|
for vac in [True, False]:
|
||||||
|
idx += 1
|
||||||
|
vac_str = "VAC=on" if vac else "VAC=off"
|
||||||
|
desc = f"{combo['backend']} / {combo['policy']}"
|
||||||
|
if combo["model"]:
|
||||||
|
desc += f" / {combo['model']}"
|
||||||
|
desc += f" / {vac_str}"
|
||||||
|
|
||||||
|
print(f"\n{'='*70}")
|
||||||
|
print(f"[{idx}/{total}] {desc}")
|
||||||
|
print(f"{'='*70}")
|
||||||
|
|
||||||
|
results = await run_single_combo(
|
||||||
|
combo, audio_files, vac=vac, lan="en", max_duration=max_duration,
|
||||||
|
)
|
||||||
|
all_results.extend(results)
|
||||||
|
|
||||||
|
# Free memory between combos
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Run comprehensive WhisperLiveKit benchmark")
|
||||||
|
parser.add_argument("--quick", action="store_true", help="Quick mode: fewer models and combos")
|
||||||
|
parser.add_argument("--json", default="benchmark_results.json", dest="json_output", help="Output JSON path")
|
||||||
|
parser.add_argument("--max-duration", type=float, default=60.0, help="Max audio duration in seconds")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
system_info = get_system_info()
|
||||||
|
combos = detect_combos(quick=args.quick)
|
||||||
|
audio_files = collect_audio_files()
|
||||||
|
|
||||||
|
print(f"System: {system_info.get('cpu', 'unknown')}, {system_info.get('ram_gb', '?')}GB RAM")
|
||||||
|
print(f"Backends: {list(system_info['backend_versions'].keys())}")
|
||||||
|
print(f"Combos to test: {len(combos)} x 2 (VAC on/off) = {len(combos)*2}")
|
||||||
|
print(f"Audio files: {[f.name for f in audio_files]}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
all_results = asyncio.run(
|
||||||
|
run_full_benchmark(combos, audio_files, max_duration=args.max_duration)
|
||||||
|
)
|
||||||
|
total_time = time.time() - t0
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"system_info": system_info,
|
||||||
|
"benchmark_date": time.strftime("%Y-%m-%d %H:%M"),
|
||||||
|
"total_benchmark_time_s": round(total_time, 1),
|
||||||
|
"n_combos": len(combos) * 2,
|
||||||
|
"n_audio_files": len(audio_files),
|
||||||
|
"results": all_results,
|
||||||
|
}
|
||||||
|
|
||||||
|
Path(args.json_output).write_text(json.dumps(output, indent=2, ensure_ascii=False))
|
||||||
|
print(f"\nBenchmark complete in {total_time:.0f}s. Results: {args.json_output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
580
scripts/python_support_matrix.py
Normal file
580
scripts/python_support_matrix.py
Normal file
@@ -0,0 +1,580 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Offline Python support matrix runner for WhisperLiveKit."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
try:
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
HAS_RICH = True
|
||||||
|
except Exception:
|
||||||
|
HAS_RICH = False
|
||||||
|
|
||||||
|
SAMPLE_URL = (
|
||||||
|
"https://github.com/pyannote/pyannote-audio/raw/develop/tutorials/assets/sample.wav"
|
||||||
|
)
|
||||||
|
SAMPLE_PATH = Path("audio_tests/support-matrix-sample.wav")
|
||||||
|
DEFAULT_LOGS_DIR = Path("outputs/python-matrix/logs")
|
||||||
|
PYTHON_VERSIONS = ("3.11", "3.12", "3.13")
|
||||||
|
CONSOLE = Console() if HAS_RICH else None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MatrixRow:
|
||||||
|
row_id: str
|
||||||
|
extras: tuple[str, ...]
|
||||||
|
backend: str
|
||||||
|
policy: str
|
||||||
|
diarization_backend: str
|
||||||
|
requires_gpu: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
CASES = (
|
||||||
|
MatrixRow(
|
||||||
|
row_id="fw-diart-cpu",
|
||||||
|
extras=("test", "cpu", "diarization-diart"),
|
||||||
|
backend="faster-whisper",
|
||||||
|
policy="simulstreaming",
|
||||||
|
diarization_backend="diart",
|
||||||
|
),
|
||||||
|
MatrixRow(
|
||||||
|
row_id="fw-sortformer-cpu",
|
||||||
|
extras=("test", "cpu", "diarization-sortformer"),
|
||||||
|
backend="faster-whisper",
|
||||||
|
policy="simulstreaming",
|
||||||
|
diarization_backend="sortformer",
|
||||||
|
),
|
||||||
|
MatrixRow(
|
||||||
|
row_id="fw-sortformer-gpu",
|
||||||
|
extras=("test", "cu129", "diarization-sortformer"),
|
||||||
|
backend="faster-whisper",
|
||||||
|
policy="simulstreaming",
|
||||||
|
diarization_backend="sortformer",
|
||||||
|
requires_gpu=True,
|
||||||
|
),
|
||||||
|
MatrixRow(
|
||||||
|
row_id="voxtral-diart-cpu",
|
||||||
|
extras=("test", "cpu", "voxtral-hf", "diarization-diart"),
|
||||||
|
backend="voxtral",
|
||||||
|
policy="voxtral",
|
||||||
|
diarization_backend="diart",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
EXPECTED_FAILURE_CASES = {
|
||||||
|
("3.11", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
|
||||||
|
("3.12", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
|
||||||
|
}
|
||||||
|
UNSUPPORTED_CASES = {
|
||||||
|
("3.13", "fw-sortformer-cpu"): "unsupported_py313_sortformer_protobuf",
|
||||||
|
("3.13", "fw-sortformer-gpu"): "unsupported_py313_sortformer_protobuf",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CaseResult:
|
||||||
|
python_version: str
|
||||||
|
row_id: str
|
||||||
|
status: Literal["PASS", "FAIL", "N/A"]
|
||||||
|
reason: str
|
||||||
|
duration_sec: float
|
||||||
|
hint: str = ""
|
||||||
|
log_path: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Minimal WhisperLiveKit offline support matrix"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout-sec",
|
||||||
|
type=int,
|
||||||
|
default=300,
|
||||||
|
help="Per-case timeout in seconds (default: 300)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logs-dir",
|
||||||
|
default=str(DEFAULT_LOGS_DIR),
|
||||||
|
help="Directory where per-case logs are written (default: outputs/python-matrix/logs)",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def safe_slug(text: str) -> str:
|
||||||
|
return text.replace("=", "-").replace("|", "__").replace("/", "-").replace(" ", "-")
|
||||||
|
|
||||||
|
|
||||||
|
def status_style(status: str) -> str:
|
||||||
|
if status == "PASS":
|
||||||
|
return "green"
|
||||||
|
if status == "FAIL":
|
||||||
|
return "bold red"
|
||||||
|
if status == "N/A":
|
||||||
|
return "yellow"
|
||||||
|
return "white"
|
||||||
|
|
||||||
|
|
||||||
|
def print_line(message: str, style: str | None = None) -> None:
|
||||||
|
if CONSOLE is None:
|
||||||
|
print(message)
|
||||||
|
return
|
||||||
|
if style:
|
||||||
|
CONSOLE.print(message, style=style, highlight=False)
|
||||||
|
else:
|
||||||
|
CONSOLE.print(message, highlight=False)
|
||||||
|
|
||||||
|
|
||||||
|
def tail_text(text: str | None, max_chars: int = 220) -> str:
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
normalized = " ".join(text.split())
|
||||||
|
if len(normalized) <= max_chars:
|
||||||
|
return normalized
|
||||||
|
return normalized[-max_chars:]
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(
|
||||||
|
cmd: list[str],
|
||||||
|
cwd: Path,
|
||||||
|
env: dict[str, str],
|
||||||
|
timeout: int | None = None,
|
||||||
|
log_path: Path | None = None,
|
||||||
|
log_section: str | None = None,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
def _append_log(
|
||||||
|
*,
|
||||||
|
command: list[str],
|
||||||
|
section: str,
|
||||||
|
returncode: int | None,
|
||||||
|
stdout: str | None,
|
||||||
|
stderr: str | None,
|
||||||
|
timed_out: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if log_path is None:
|
||||||
|
return
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with log_path.open("a", encoding="utf-8") as f:
|
||||||
|
f.write(f"\n=== {section} ===\n")
|
||||||
|
f.write(f"$ {shlex.join(command)}\n")
|
||||||
|
if timed_out:
|
||||||
|
f.write("status: timeout\n")
|
||||||
|
else:
|
||||||
|
f.write(f"status: exit_code={returncode}\n")
|
||||||
|
if stdout:
|
||||||
|
f.write("--- stdout ---\n")
|
||||||
|
f.write(stdout)
|
||||||
|
if not stdout.endswith("\n"):
|
||||||
|
f.write("\n")
|
||||||
|
if stderr:
|
||||||
|
f.write("--- stderr ---\n")
|
||||||
|
f.write(stderr)
|
||||||
|
if not stderr.endswith("\n"):
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
section = log_section or "command"
|
||||||
|
try:
|
||||||
|
proc = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
cwd=str(cwd),
|
||||||
|
env=env,
|
||||||
|
text=True,
|
||||||
|
capture_output=True,
|
||||||
|
check=False,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired as exc:
|
||||||
|
_append_log(
|
||||||
|
command=cmd,
|
||||||
|
section=section,
|
||||||
|
returncode=None,
|
||||||
|
stdout=exc.stdout if isinstance(exc.stdout, str) else None,
|
||||||
|
stderr=exc.stderr if isinstance(exc.stderr, str) else None,
|
||||||
|
timed_out=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
_append_log(
|
||||||
|
command=cmd,
|
||||||
|
section=section,
|
||||||
|
returncode=proc.returncode,
|
||||||
|
stdout=proc.stdout,
|
||||||
|
stderr=proc.stderr,
|
||||||
|
)
|
||||||
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
def detect_gpu_available() -> bool:
|
||||||
|
try:
|
||||||
|
proc = subprocess.run(
|
||||||
|
["nvidia-smi", "-L"],
|
||||||
|
text=True,
|
||||||
|
capture_output=True,
|
||||||
|
check=False,
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||||
|
return False
|
||||||
|
return proc.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
def download_sample(repo_root: Path) -> Path:
|
||||||
|
target = repo_root / SAMPLE_PATH
|
||||||
|
target.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
cmd = [
|
||||||
|
"curl",
|
||||||
|
"--fail",
|
||||||
|
"--location",
|
||||||
|
"--silent",
|
||||||
|
"--show-error",
|
||||||
|
SAMPLE_URL,
|
||||||
|
"--output",
|
||||||
|
str(target),
|
||||||
|
]
|
||||||
|
proc = run_command(cmd, cwd=repo_root, env=os.environ.copy())
|
||||||
|
if proc.returncode != 0:
|
||||||
|
hint = tail_text(proc.stderr or proc.stdout)
|
||||||
|
raise RuntimeError(f"sample_download_failed: {hint}")
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
|
def sync_case_environment(
|
||||||
|
repo_root: Path,
|
||||||
|
python_version: str,
|
||||||
|
row: MatrixRow,
|
||||||
|
env_dir: Path,
|
||||||
|
log_path: Path,
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
cmd = ["uv", "sync", "--python", python_version, "--no-dev"]
|
||||||
|
for extra in row.extras:
|
||||||
|
cmd.extend(["--extra", extra])
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
|
||||||
|
proc = run_command(
|
||||||
|
cmd,
|
||||||
|
cwd=repo_root,
|
||||||
|
env=env,
|
||||||
|
log_path=log_path,
|
||||||
|
log_section="sync",
|
||||||
|
)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
return False, tail_text(proc.stderr or proc.stdout)
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
def apply_expected_failure_policy(result: CaseResult) -> CaseResult:
|
||||||
|
expected_reason = EXPECTED_FAILURE_CASES.get((result.python_version, result.row_id))
|
||||||
|
if result.status != "FAIL" or not expected_reason:
|
||||||
|
return result
|
||||||
|
override_hint = result.hint
|
||||||
|
if result.reason:
|
||||||
|
override_hint = (
|
||||||
|
f"expected_failure_override original_reason={result.reason}; {override_hint}"
|
||||||
|
if override_hint
|
||||||
|
else f"expected_failure_override original_reason={result.reason}"
|
||||||
|
)
|
||||||
|
return CaseResult(
|
||||||
|
python_version=result.python_version,
|
||||||
|
row_id=result.row_id,
|
||||||
|
status="N/A",
|
||||||
|
reason=expected_reason,
|
||||||
|
duration_sec=result.duration_sec,
|
||||||
|
hint=override_hint,
|
||||||
|
log_path=result.log_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_offline_command(
|
||||||
|
python_version: str,
|
||||||
|
row: MatrixRow,
|
||||||
|
sample_audio: Path,
|
||||||
|
timeout_sec: int,
|
||||||
|
) -> tuple[list[str], int | None]:
|
||||||
|
base_cmd = [
|
||||||
|
"uv",
|
||||||
|
"run",
|
||||||
|
"--python",
|
||||||
|
python_version,
|
||||||
|
"--no-sync",
|
||||||
|
"python",
|
||||||
|
"test_backend_offline.py",
|
||||||
|
"--backend",
|
||||||
|
row.backend,
|
||||||
|
"--policy",
|
||||||
|
row.policy,
|
||||||
|
"--audio",
|
||||||
|
str(sample_audio),
|
||||||
|
"--model",
|
||||||
|
"tiny",
|
||||||
|
"--diarization",
|
||||||
|
"--diarization-backend",
|
||||||
|
row.diarization_backend,
|
||||||
|
"--lan",
|
||||||
|
"en",
|
||||||
|
"--no-realtime",
|
||||||
|
]
|
||||||
|
if shutil.which("timeout"):
|
||||||
|
return ["timeout", str(timeout_sec), *base_cmd], None
|
||||||
|
return base_cmd, timeout_sec
|
||||||
|
|
||||||
|
|
||||||
|
def run_case(
|
||||||
|
repo_root: Path,
|
||||||
|
python_version: str,
|
||||||
|
row: MatrixRow,
|
||||||
|
sample_audio: Path,
|
||||||
|
timeout_sec: int,
|
||||||
|
gpu_available: bool,
|
||||||
|
logs_dir: Path,
|
||||||
|
) -> CaseResult:
|
||||||
|
start = time.monotonic()
|
||||||
|
case_slug = safe_slug(f"py{python_version}-{row.row_id}")
|
||||||
|
log_path = logs_dir / f"run-{case_slug}.log"
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
log_path.write_text("", encoding="utf-8")
|
||||||
|
|
||||||
|
unsupported_reason = UNSUPPORTED_CASES.get((python_version, row.row_id))
|
||||||
|
if unsupported_reason:
|
||||||
|
log_path.write_text(
|
||||||
|
f"[matrix] precheck_short_circuit status=N/A reason={unsupported_reason}\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return CaseResult(
|
||||||
|
python_version=python_version,
|
||||||
|
row_id=row.row_id,
|
||||||
|
status="N/A",
|
||||||
|
reason=unsupported_reason,
|
||||||
|
duration_sec=0.0,
|
||||||
|
hint="unsupported_case_precheck",
|
||||||
|
log_path=str(log_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
if row.requires_gpu and not gpu_available:
|
||||||
|
return CaseResult(
|
||||||
|
python_version=python_version,
|
||||||
|
row_id=row.row_id,
|
||||||
|
status="N/A",
|
||||||
|
reason="gpu_unavailable",
|
||||||
|
duration_sec=0.0,
|
||||||
|
hint="nvidia-smi unavailable or failed",
|
||||||
|
log_path=str(log_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
env_dir = repo_root / ".matrix-envs" / safe_slug(f"py{python_version}-{row.row_id}")
|
||||||
|
sync_ok, sync_hint = sync_case_environment(
|
||||||
|
repo_root,
|
||||||
|
python_version,
|
||||||
|
row,
|
||||||
|
env_dir,
|
||||||
|
log_path=log_path,
|
||||||
|
)
|
||||||
|
if not sync_ok:
|
||||||
|
return CaseResult(
|
||||||
|
python_version=python_version,
|
||||||
|
row_id=row.row_id,
|
||||||
|
status="FAIL",
|
||||||
|
reason="dependency_sync_failed",
|
||||||
|
duration_sec=round(time.monotonic() - start, 3),
|
||||||
|
hint=sync_hint,
|
||||||
|
log_path=str(log_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd, process_timeout = build_offline_command(
|
||||||
|
python_version, row, sample_audio, timeout_sec
|
||||||
|
)
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
|
||||||
|
if row.requires_gpu:
|
||||||
|
env.pop("CUDA_VISIBLE_DEVICES", None)
|
||||||
|
else:
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
try:
|
||||||
|
proc = run_command(
|
||||||
|
cmd,
|
||||||
|
cwd=repo_root,
|
||||||
|
env=env,
|
||||||
|
timeout=process_timeout,
|
||||||
|
log_path=log_path,
|
||||||
|
log_section="offline",
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired as exc:
|
||||||
|
return CaseResult(
|
||||||
|
python_version=python_version,
|
||||||
|
row_id=row.row_id,
|
||||||
|
status="FAIL",
|
||||||
|
reason="offline_timeout",
|
||||||
|
duration_sec=round(time.monotonic() - start, 3),
|
||||||
|
hint=tail_text((exc.stderr or "") if isinstance(exc.stderr, str) else ""),
|
||||||
|
log_path=str(log_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
hint = tail_text(proc.stderr or proc.stdout)
|
||||||
|
if proc.returncode == 0:
|
||||||
|
return CaseResult(
|
||||||
|
python_version=python_version,
|
||||||
|
row_id=row.row_id,
|
||||||
|
status="PASS",
|
||||||
|
reason="ok",
|
||||||
|
duration_sec=round(time.monotonic() - start, 3),
|
||||||
|
hint=hint,
|
||||||
|
log_path=str(log_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
reason = "offline_timeout" if proc.returncode == 124 else "offline_run_failed"
|
||||||
|
return CaseResult(
|
||||||
|
python_version=python_version,
|
||||||
|
row_id=row.row_id,
|
||||||
|
status="FAIL",
|
||||||
|
reason=reason,
|
||||||
|
duration_sec=round(time.monotonic() - start, 3),
|
||||||
|
hint=hint,
|
||||||
|
log_path=str(log_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def print_summary(results: list[CaseResult]) -> None:
|
||||||
|
pass_count = sum(1 for row in results if row.status == "PASS")
|
||||||
|
fail_count = sum(1 for row in results if row.status == "FAIL")
|
||||||
|
na_count = sum(1 for row in results if row.status == "N/A")
|
||||||
|
if CONSOLE is None:
|
||||||
|
print("\n[matrix] results")
|
||||||
|
print("python | row | status | reason | duration_s")
|
||||||
|
print("---|---|---|---|---")
|
||||||
|
for result in results:
|
||||||
|
print(
|
||||||
|
f"{result.python_version} | {result.row_id} | {result.status} | "
|
||||||
|
f"{result.reason} | {result.duration_sec:.3f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"\n[matrix] summary pass={pass_count} fail={fail_count} "
|
||||||
|
f"na={na_count} total={len(results)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
table = Table(title="Support Matrix Results")
|
||||||
|
table.add_column("Python", style="cyan", no_wrap=True)
|
||||||
|
table.add_column("Row", style="white")
|
||||||
|
table.add_column("Status", no_wrap=True)
|
||||||
|
table.add_column("Reason")
|
||||||
|
table.add_column("Duration (s)", justify="right", no_wrap=True)
|
||||||
|
for result in results:
|
||||||
|
table.add_row(
|
||||||
|
result.python_version,
|
||||||
|
result.row_id,
|
||||||
|
f"[{status_style(result.status)}]{result.status}[/{status_style(result.status)}]",
|
||||||
|
result.reason,
|
||||||
|
f"{result.duration_sec:.3f}",
|
||||||
|
)
|
||||||
|
CONSOLE.print()
|
||||||
|
CONSOLE.print(table)
|
||||||
|
CONSOLE.print(
|
||||||
|
f"[bold]Summary[/bold] "
|
||||||
|
f"pass=[green]{pass_count}[/green] "
|
||||||
|
f"fail=[bold red]{fail_count}[/bold red] "
|
||||||
|
f"na=[yellow]{na_count}[/yellow] "
|
||||||
|
f"total={len(results)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
diagnostics = [row for row in results if row.status in {"FAIL", "N/A"} and row.hint]
|
||||||
|
if diagnostics:
|
||||||
|
if CONSOLE is None:
|
||||||
|
print("\n[matrix] diagnostics (failed/n-a cases)")
|
||||||
|
for row in diagnostics:
|
||||||
|
print(
|
||||||
|
f"- py={row.python_version} row={row.row_id} "
|
||||||
|
f"status={row.status} reason={row.reason}"
|
||||||
|
)
|
||||||
|
print(f" hint: {row.hint}")
|
||||||
|
if row.log_path:
|
||||||
|
print(f" log: {row.log_path}")
|
||||||
|
else:
|
||||||
|
diagnostics_table = Table(title="Diagnostics (FAIL / N/A)")
|
||||||
|
diagnostics_table.add_column("Case", style="cyan")
|
||||||
|
diagnostics_table.add_column("Status", no_wrap=True)
|
||||||
|
diagnostics_table.add_column("Reason")
|
||||||
|
diagnostics_table.add_column("Hint")
|
||||||
|
diagnostics_table.add_column("Log")
|
||||||
|
for row in diagnostics:
|
||||||
|
diagnostics_table.add_row(
|
||||||
|
f"py={row.python_version} {row.row_id}",
|
||||||
|
f"[{status_style(row.status)}]{row.status}[/{status_style(row.status)}]",
|
||||||
|
row.reason,
|
||||||
|
row.hint,
|
||||||
|
row.log_path,
|
||||||
|
)
|
||||||
|
CONSOLE.print()
|
||||||
|
CONSOLE.print(diagnostics_table)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
args = parse_args()
|
||||||
|
if args.timeout_sec <= 0:
|
||||||
|
print("[matrix] error: --timeout-sec must be > 0", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
repo_root = Path(__file__).resolve().parents[1]
|
||||||
|
logs_dir = (repo_root / args.logs_dir).resolve()
|
||||||
|
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
print_line(f"[matrix] repo_root={repo_root}", style="cyan")
|
||||||
|
print_line(f"[matrix] timeout_sec={args.timeout_sec}", style="cyan")
|
||||||
|
print_line(f"[matrix] logs_dir={logs_dir}", style="cyan")
|
||||||
|
|
||||||
|
try:
|
||||||
|
sample_audio = download_sample(repo_root)
|
||||||
|
except Exception as exc: # pragma: no cover - straightforward failure path
|
||||||
|
if CONSOLE is None:
|
||||||
|
print(f"[matrix] sample_download_failed: {exc}", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
CONSOLE.print(
|
||||||
|
f"[matrix] sample_download_failed: {exc}",
|
||||||
|
style="bold red",
|
||||||
|
highlight=False,
|
||||||
|
)
|
||||||
|
return 1
|
||||||
|
print_line(f"[matrix] sample_audio={sample_audio}", style="cyan")
|
||||||
|
|
||||||
|
gpu_available = detect_gpu_available()
|
||||||
|
print_line(f"[matrix] gpu_available={gpu_available}", style="cyan")
|
||||||
|
|
||||||
|
results: list[CaseResult] = []
|
||||||
|
for python_version in PYTHON_VERSIONS:
|
||||||
|
for row in CASES:
|
||||||
|
print_line(
|
||||||
|
f"\n[matrix] running py={python_version} row={row.row_id}", style="blue"
|
||||||
|
)
|
||||||
|
result = run_case(
|
||||||
|
repo_root=repo_root,
|
||||||
|
python_version=python_version,
|
||||||
|
row=row,
|
||||||
|
sample_audio=sample_audio,
|
||||||
|
timeout_sec=args.timeout_sec,
|
||||||
|
gpu_available=gpu_available,
|
||||||
|
logs_dir=logs_dir,
|
||||||
|
)
|
||||||
|
result = apply_expected_failure_policy(result)
|
||||||
|
results.append(result)
|
||||||
|
print_line(
|
||||||
|
f"[matrix] {result.status} py={result.python_version} "
|
||||||
|
f"row={result.row_id} reason={result.reason} duration={result.duration_sec:.3f}s",
|
||||||
|
style=status_style(result.status),
|
||||||
|
)
|
||||||
|
if result.log_path:
|
||||||
|
print_line(f"[matrix] log={result.log_path}", style="dim")
|
||||||
|
|
||||||
|
print_summary(results)
|
||||||
|
fail_count = sum(1 for row in results if row.status == "FAIL")
|
||||||
|
return 1 if fail_count else 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
803
test_backend_offline.py
Normal file
803
test_backend_offline.py
Normal file
@@ -0,0 +1,803 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Offline test harness and benchmark suite for WhisperLiveKit backends.
|
||||||
|
|
||||||
|
Simulates a client-server session by feeding audio files as PCM bytes through
|
||||||
|
the full AudioProcessor pipeline (the same path used by the WebSocket server),
|
||||||
|
without needing a browser or microphone.
|
||||||
|
|
||||||
|
Computes WER (Word Error Rate) and timestamp accuracy when ground truth
|
||||||
|
transcript files (.transcript.json) are available alongside audio files.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Test with a single audio file:
|
||||||
|
python test_backend_offline.py --backend faster-whisper --audio audio_tests/00_00_07_english_1_speaker.wav
|
||||||
|
|
||||||
|
# Test all files in audio_tests/:
|
||||||
|
python test_backend_offline.py --backend faster-whisper --no-realtime
|
||||||
|
|
||||||
|
# Override streaming policy:
|
||||||
|
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --no-realtime
|
||||||
|
|
||||||
|
# Multi-backend benchmark (auto-detects all installed backends):
|
||||||
|
python test_backend_offline.py --benchmark --no-realtime
|
||||||
|
|
||||||
|
# Export results as JSON:
|
||||||
|
python test_backend_offline.py --benchmark --no-realtime --json results.json
|
||||||
|
|
||||||
|
# Insert silence for testing silence handling:
|
||||||
|
python test_backend_offline.py --backend faster-whisper --insert-silence 3.0 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.WARNING,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("test_offline")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
JFK_WAV_URL = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||||
|
CACHE_DIR = Path(__file__).parent / ".test_cache"
|
||||||
|
AUDIO_TESTS_DIR = Path(__file__).parent / "audio_tests"
|
||||||
|
AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WordTimestamp:
|
||||||
|
"""Word with its start/end time."""
|
||||||
|
word: str
|
||||||
|
start: float
|
||||||
|
end: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestResult:
|
||||||
|
"""Structured result from a single test run."""
|
||||||
|
audio_file: str
|
||||||
|
audio_duration_s: float
|
||||||
|
backend: str
|
||||||
|
policy: str
|
||||||
|
language: str
|
||||||
|
chunk_ms: int
|
||||||
|
realtime_pacing: bool
|
||||||
|
# Timing
|
||||||
|
processing_time_s: float
|
||||||
|
rtf: float # real-time factor
|
||||||
|
# Transcription output
|
||||||
|
transcription: str
|
||||||
|
n_lines: int
|
||||||
|
n_responses: int
|
||||||
|
# WER metrics (None if no ground truth)
|
||||||
|
wer: Optional[float] = None
|
||||||
|
wer_details: Optional[dict] = None
|
||||||
|
# Timestamp accuracy (None if no ground truth)
|
||||||
|
timestamp_mae: Optional[float] = None
|
||||||
|
timestamp_max_delta: Optional[float] = None
|
||||||
|
timestamp_median_delta: Optional[float] = None
|
||||||
|
# Word-level timestamps
|
||||||
|
word_timestamps: List[WordTimestamp] = field(default_factory=list)
|
||||||
|
# Raw last response
|
||||||
|
last_response: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
def download_sample_audio() -> Path:
|
||||||
|
"""Download the jfk.wav sample if not cached."""
|
||||||
|
CACHE_DIR.mkdir(exist_ok=True)
|
||||||
|
path = CACHE_DIR / "jfk.wav"
|
||||||
|
if not path.exists():
|
||||||
|
logger.info(f"Downloading sample audio to {path} ...")
|
||||||
|
urllib.request.urlretrieve(JFK_WAV_URL, path)
|
||||||
|
logger.info("Done.")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(path: str) -> np.ndarray:
|
||||||
|
"""Load audio file as float32 mono 16kHz numpy array.
|
||||||
|
|
||||||
|
Supports WAV, FLAC (via soundfile) and MP3, OGG, M4A (via librosa).
|
||||||
|
"""
|
||||||
|
ext = Path(path).suffix.lower()
|
||||||
|
if ext in (".mp3", ".ogg", ".m4a"):
|
||||||
|
import librosa
|
||||||
|
audio, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
|
||||||
|
return audio.astype(np.float32)
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
|
audio, sr = sf.read(path, dtype="float32")
|
||||||
|
if audio.ndim > 1:
|
||||||
|
audio = audio.mean(axis=1)
|
||||||
|
if sr != SAMPLE_RATE:
|
||||||
|
import librosa
|
||||||
|
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def insert_silence(audio: np.ndarray, silence_sec: float, position_sec: float) -> np.ndarray:
|
||||||
|
"""Insert silence into audio at a given position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: Float32 mono audio array at SAMPLE_RATE.
|
||||||
|
silence_sec: Duration of silence to insert in seconds.
|
||||||
|
position_sec: Position in seconds where silence starts.
|
||||||
|
Returns:
|
||||||
|
New audio array with silence inserted.
|
||||||
|
"""
|
||||||
|
pos_samples = int(position_sec * SAMPLE_RATE)
|
||||||
|
silence_samples = int(silence_sec * SAMPLE_RATE)
|
||||||
|
pos_samples = min(pos_samples, len(audio))
|
||||||
|
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||||
|
return np.concatenate([audio[:pos_samples], silence, audio[pos_samples:]])
|
||||||
|
|
||||||
|
|
||||||
|
def float32_to_s16le_bytes(audio: np.ndarray) -> bytes:
|
||||||
|
"""Convert float32 audio to s16le PCM bytes (what the browser sends)."""
|
||||||
|
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def create_engine(
|
||||||
|
backend: str, model_size: str, lan: str,
|
||||||
|
diarization: bool = False,
|
||||||
|
diarization_backend: str = "",
|
||||||
|
vac: bool = True,
|
||||||
|
policy: str = "",
|
||||||
|
):
|
||||||
|
"""Create a TranscriptionEngine with the given backend config."""
|
||||||
|
import gc
|
||||||
|
from whisperlivekit.core import TranscriptionEngine
|
||||||
|
|
||||||
|
# Reset singleton so we get a fresh instance
|
||||||
|
TranscriptionEngine._instance = None
|
||||||
|
TranscriptionEngine._initialized = False
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
backend=backend,
|
||||||
|
lan=lan,
|
||||||
|
pcm_input=True,
|
||||||
|
vac=vac,
|
||||||
|
transcription=True,
|
||||||
|
diarization=diarization,
|
||||||
|
)
|
||||||
|
if diarization_backend:
|
||||||
|
kwargs["diarization_backend"] = diarization_backend
|
||||||
|
if model_size:
|
||||||
|
kwargs["model_size"] = model_size
|
||||||
|
if policy:
|
||||||
|
kwargs["backend_policy"] = policy
|
||||||
|
|
||||||
|
return TranscriptionEngine(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_response(response_dict: dict) -> str:
|
||||||
|
"""Extract full transcription text from a FrontData dict."""
|
||||||
|
def _strip_or_empty(value: object) -> str:
|
||||||
|
return value.strip() if isinstance(value, str) else ""
|
||||||
|
|
||||||
|
segments = response_dict.get("lines", [])
|
||||||
|
full_text = " ".join(
|
||||||
|
text
|
||||||
|
for seg in segments
|
||||||
|
if isinstance(seg, dict)
|
||||||
|
for text in [_strip_or_empty(seg.get("text"))]
|
||||||
|
if text
|
||||||
|
)
|
||||||
|
buf = _strip_or_empty(response_dict.get("buffer_transcription"))
|
||||||
|
if buf:
|
||||||
|
full_text = f"{full_text} {buf}".strip() if full_text else buf
|
||||||
|
return full_text
|
||||||
|
|
||||||
|
|
||||||
|
async def run_test(
|
||||||
|
engine, audio: np.ndarray, chunk_ms: int, realtime: bool,
|
||||||
|
audio_file: str = "", backend: str = "", policy: str = "", lan: str = "",
|
||||||
|
) -> TestResult:
|
||||||
|
"""
|
||||||
|
Simulate a client session through the full AudioProcessor pipeline.
|
||||||
|
|
||||||
|
1. Create AudioProcessor (one per "client session")
|
||||||
|
2. Start async pipeline (transcription_processor, results_formatter, etc.)
|
||||||
|
3. Feed audio as PCM bytes in timed chunks
|
||||||
|
4. Collect and display FrontData responses
|
||||||
|
5. Signal EOF and cleanup
|
||||||
|
"""
|
||||||
|
from whisperlivekit.audio_processor import AudioProcessor
|
||||||
|
|
||||||
|
chunk_samples = int(SAMPLE_RATE * chunk_ms / 1000)
|
||||||
|
total_samples = len(audio)
|
||||||
|
audio_duration = total_samples / SAMPLE_RATE
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Audio: {audio_duration:.2f}s | "
|
||||||
|
f"Chunk: {chunk_ms}ms ({chunk_samples} samples) | "
|
||||||
|
f"Steps: {total_samples // chunk_samples + 1} | "
|
||||||
|
f"Realtime: {realtime}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Server side: create processor and start pipeline ---
|
||||||
|
processor = AudioProcessor(transcription_engine=engine)
|
||||||
|
results_generator = await processor.create_tasks()
|
||||||
|
|
||||||
|
# Collect results in background (like handle_websocket_results)
|
||||||
|
all_responses = []
|
||||||
|
response_count = 0
|
||||||
|
last_printed_text = ""
|
||||||
|
|
||||||
|
async def collect_results():
|
||||||
|
nonlocal response_count, last_printed_text
|
||||||
|
async for response in results_generator:
|
||||||
|
all_responses.append(response)
|
||||||
|
response_count += 1
|
||||||
|
d = response.to_dict()
|
||||||
|
|
||||||
|
# Only print when transcription text actually changes
|
||||||
|
current_text = _extract_text_from_response(d)
|
||||||
|
if current_text and current_text != last_printed_text:
|
||||||
|
buf = d.get("buffer_transcription")
|
||||||
|
buf = buf.strip() if isinstance(buf, str) else ""
|
||||||
|
committed = current_text
|
||||||
|
if buf and committed.endswith(buf):
|
||||||
|
committed = committed[:-len(buf)].strip()
|
||||||
|
|
||||||
|
# Show committed text + buffer separately
|
||||||
|
display = committed
|
||||||
|
if buf:
|
||||||
|
display = f"{committed} \033[90m{buf}\033[0m" if committed else f"\033[90m{buf}\033[0m"
|
||||||
|
print(f" > {display}", flush=True)
|
||||||
|
last_printed_text = current_text
|
||||||
|
|
||||||
|
result_task = asyncio.create_task(collect_results())
|
||||||
|
|
||||||
|
# --- Client side: feed audio as PCM bytes ---
|
||||||
|
t_start = time.time()
|
||||||
|
|
||||||
|
for offset in range(0, total_samples, chunk_samples):
|
||||||
|
chunk = audio[offset : offset + chunk_samples]
|
||||||
|
pcm_bytes = float32_to_s16le_bytes(chunk)
|
||||||
|
await processor.process_audio(pcm_bytes)
|
||||||
|
if realtime:
|
||||||
|
await asyncio.sleep(chunk_ms / 1000)
|
||||||
|
|
||||||
|
feed_elapsed = time.time() - t_start
|
||||||
|
|
||||||
|
logger.info(f"Audio fed in {feed_elapsed:.2f}s. Signaling EOF...")
|
||||||
|
|
||||||
|
# Signal end of audio (like client disconnect / empty message)
|
||||||
|
await processor.process_audio(None)
|
||||||
|
|
||||||
|
# Wait for pipeline to drain completely
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(result_task, timeout=120.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("Timed out waiting for results. Proceeding with cleanup.")
|
||||||
|
result_task.cancel()
|
||||||
|
try:
|
||||||
|
await result_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# --- Capture word-level timestamps before cleanup ---
|
||||||
|
word_timestamps = []
|
||||||
|
try:
|
||||||
|
state = await processor.get_current_state()
|
||||||
|
for token in state.tokens:
|
||||||
|
if hasattr(token, 'start') and hasattr(token, 'text') and token.text:
|
||||||
|
word_timestamps.append(WordTimestamp(
|
||||||
|
word=token.text.strip(),
|
||||||
|
start=round(token.start, 3),
|
||||||
|
end=round(token.end, 3),
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not capture word timestamps: {e}")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await processor.cleanup()
|
||||||
|
|
||||||
|
total_elapsed = time.time() - t_start
|
||||||
|
|
||||||
|
# --- Build result ---
|
||||||
|
transcription = ""
|
||||||
|
n_lines = 0
|
||||||
|
last_response_dict = None
|
||||||
|
|
||||||
|
if all_responses:
|
||||||
|
last = all_responses[-1].to_dict()
|
||||||
|
last_response_dict = last
|
||||||
|
n_lines = len(last.get("lines", []))
|
||||||
|
transcription = _extract_text_from_response(last)
|
||||||
|
|
||||||
|
# --- Compute WER and timestamp accuracy against ground truth ---
|
||||||
|
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy
|
||||||
|
|
||||||
|
wer_val = None
|
||||||
|
wer_details = None
|
||||||
|
ts_mae = None
|
||||||
|
ts_max_delta = None
|
||||||
|
ts_median_delta = None
|
||||||
|
|
||||||
|
gt_path = Path(audio_file).with_suffix(".transcript.json")
|
||||||
|
if not gt_path.exists():
|
||||||
|
gt_path = AUDIO_TESTS_DIR / gt_path
|
||||||
|
gt = None
|
||||||
|
if gt_path.exists():
|
||||||
|
with open(gt_path) as f:
|
||||||
|
gt = json.load(f)
|
||||||
|
|
||||||
|
# WER
|
||||||
|
gt_text = " ".join(w["word"] for w in gt)
|
||||||
|
wer_result = compute_wer(gt_text, transcription)
|
||||||
|
wer_val = round(wer_result["wer"], 4)
|
||||||
|
wer_details = wer_result
|
||||||
|
|
||||||
|
# Timestamp accuracy
|
||||||
|
if word_timestamps:
|
||||||
|
pred_dicts = [{"word": wt.word, "start": wt.start, "end": wt.end} for wt in word_timestamps]
|
||||||
|
ts_result = compute_timestamp_accuracy(pred_dicts, gt)
|
||||||
|
ts_mae = ts_result["mae_start"]
|
||||||
|
ts_max_delta = ts_result["max_delta_start"]
|
||||||
|
ts_median_delta = ts_result["median_delta_start"]
|
||||||
|
|
||||||
|
result = TestResult(
|
||||||
|
audio_file=audio_file,
|
||||||
|
audio_duration_s=round(audio_duration, 2),
|
||||||
|
backend=backend,
|
||||||
|
policy=policy,
|
||||||
|
language=lan,
|
||||||
|
chunk_ms=chunk_ms,
|
||||||
|
realtime_pacing=realtime,
|
||||||
|
processing_time_s=round(total_elapsed, 2),
|
||||||
|
rtf=round(total_elapsed / audio_duration, 2),
|
||||||
|
transcription=transcription,
|
||||||
|
n_lines=n_lines,
|
||||||
|
n_responses=response_count,
|
||||||
|
wer=wer_val,
|
||||||
|
wer_details=wer_details,
|
||||||
|
timestamp_mae=round(ts_mae, 3) if ts_mae is not None else None,
|
||||||
|
timestamp_max_delta=round(ts_max_delta, 3) if ts_max_delta is not None else None,
|
||||||
|
timestamp_median_delta=round(ts_median_delta, 3) if ts_median_delta is not None else None,
|
||||||
|
word_timestamps=word_timestamps,
|
||||||
|
last_response=last_response_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Print summary ---
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(f"RESULT: {audio_file}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print(f"Transcription: {transcription}")
|
||||||
|
print(f"Lines: {n_lines} | Responses: {response_count}")
|
||||||
|
print(f"Audio: {audio_duration:.2f}s | Time: {total_elapsed:.2f}s | RTF: {result.rtf:.2f}x")
|
||||||
|
|
||||||
|
if wer_val is not None:
|
||||||
|
print(f"WER: {wer_val:.2%} (S={wer_details['substitutions']} I={wer_details['insertions']} D={wer_details['deletions']})")
|
||||||
|
|
||||||
|
# Print word timestamps if available
|
||||||
|
if word_timestamps:
|
||||||
|
print(f"\nWord timestamps ({len(word_timestamps)} words):")
|
||||||
|
for wt in word_timestamps:
|
||||||
|
print(f" [{wt.start:6.2f} - {wt.end:6.2f}] {wt.word}")
|
||||||
|
|
||||||
|
# Detailed comparison with ground truth
|
||||||
|
if gt:
|
||||||
|
print(f"\n vs Ground truth ({len(gt)} words):")
|
||||||
|
max_words = max(len(word_timestamps), len(gt))
|
||||||
|
for i in range(max_words):
|
||||||
|
pred = word_timestamps[i] if i < len(word_timestamps) else None
|
||||||
|
ref = gt[i] if i < len(gt) else None
|
||||||
|
p_str = f"[{pred.start:5.2f}-{pred.end:5.2f}] {pred.word:<15}" if pred else " " * 30
|
||||||
|
r_str = f"[{ref['start']:5.2f}-{ref['end']:5.2f}] {ref['word']:<15}" if ref else ""
|
||||||
|
delta = ""
|
||||||
|
if pred and ref:
|
||||||
|
d = pred.start - ref['start']
|
||||||
|
delta = f" Δstart={d:+.2f}"
|
||||||
|
print(f" {p_str} | {r_str}{delta}")
|
||||||
|
|
||||||
|
if ts_mae is not None:
|
||||||
|
print(f"\n Timestamp stats: MAE={ts_mae:.3f}s max|Δ|={ts_max_delta:.3f}s median|Δ|={ts_median_delta:.3f}s")
|
||||||
|
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def discover_audio_files(directory: str) -> List[Path]:
|
||||||
|
"""Find all supported audio files in directory."""
|
||||||
|
d = Path(directory)
|
||||||
|
files = sorted(
|
||||||
|
p for p in d.iterdir()
|
||||||
|
if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
|
||||||
|
)
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
async def run_all_tests(
|
||||||
|
engine, audio_files: List[Path], chunk_ms: int, realtime: bool,
|
||||||
|
backend: str, policy: str, lan: str, max_duration: float = 60.0,
|
||||||
|
silence_insertions: Optional[List[List[float]]] = None,
|
||||||
|
) -> List[TestResult]:
|
||||||
|
"""Run tests on multiple audio files sequentially."""
|
||||||
|
results = []
|
||||||
|
for audio_path in audio_files:
|
||||||
|
# Detect language from filename if "french" in name
|
||||||
|
file_lan = lan
|
||||||
|
if "french" in audio_path.name.lower() and lan == "en":
|
||||||
|
file_lan = "fr"
|
||||||
|
logger.info(f"Auto-detected language 'fr' from filename")
|
||||||
|
|
||||||
|
audio = load_audio(str(audio_path))
|
||||||
|
|
||||||
|
# Insert silence segments (applied in reverse position order to keep offsets valid)
|
||||||
|
if silence_insertions:
|
||||||
|
for secs, at_sec in sorted(silence_insertions, key=lambda x: x[1], reverse=True):
|
||||||
|
logger.info(f"Inserting {secs:.1f}s silence at {at_sec:.1f}s")
|
||||||
|
audio = insert_silence(audio, secs, at_sec)
|
||||||
|
|
||||||
|
duration = len(audio) / SAMPLE_RATE
|
||||||
|
|
||||||
|
if duration > max_duration:
|
||||||
|
logger.info(f"Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s max)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n{'#' * 60}")
|
||||||
|
print(f"# Testing: {audio_path.name} ({duration:.1f}s)")
|
||||||
|
print(f"{'#' * 60}")
|
||||||
|
|
||||||
|
result = await run_test(
|
||||||
|
engine, audio, chunk_ms, realtime,
|
||||||
|
audio_file=audio_path.name, backend=backend, policy=policy, lan=file_lan,
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_benchmark_summary(results: List[TestResult]):
|
||||||
|
"""Print a tabular summary of all test results."""
|
||||||
|
print(f"\n{'=' * 110}")
|
||||||
|
print("BENCHMARK SUMMARY")
|
||||||
|
print(f"{'=' * 110}")
|
||||||
|
print(
|
||||||
|
f"{'File':<40} {'Duration':>8} {'Time':>8} {'RTF':>6} "
|
||||||
|
f"{'WER':>7} {'MAE(s)':>7} {'Lines':>5}"
|
||||||
|
)
|
||||||
|
print(f"{'-' * 110}")
|
||||||
|
for r in results:
|
||||||
|
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
|
||||||
|
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
|
||||||
|
print(
|
||||||
|
f"{r.audio_file:<40} {r.audio_duration_s:>7.1f}s {r.processing_time_s:>7.1f}s "
|
||||||
|
f"{r.rtf:>5.2f}x {wer_str:>7} {mae_str:>7} {r.n_lines:>5}"
|
||||||
|
)
|
||||||
|
print(f"{'-' * 110}")
|
||||||
|
total_audio = sum(r.audio_duration_s for r in results)
|
||||||
|
total_time = sum(r.processing_time_s for r in results)
|
||||||
|
avg_rtf = total_time / total_audio if total_audio > 0 else 0
|
||||||
|
wer_vals = [r.wer for r in results if r.wer is not None]
|
||||||
|
avg_wer_str = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
|
||||||
|
mae_vals = [r.timestamp_mae for r in results if r.timestamp_mae is not None]
|
||||||
|
avg_mae_str = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
|
||||||
|
print(
|
||||||
|
f"{'TOTAL/AVG':<40} {total_audio:>7.1f}s {total_time:>7.1f}s "
|
||||||
|
f"{avg_rtf:>5.2f}x {avg_wer_str:>7} {avg_mae_str:>7}"
|
||||||
|
)
|
||||||
|
print(f"{'=' * 110}")
|
||||||
|
|
||||||
|
# Print transcription excerpts
|
||||||
|
print(f"\nTRANSCRIPTIONS:")
|
||||||
|
print(f"{'-' * 110}")
|
||||||
|
for r in results:
|
||||||
|
excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription
|
||||||
|
print(f" {r.audio_file}:")
|
||||||
|
print(f" {excerpt}")
|
||||||
|
print(f"{'=' * 110}")
|
||||||
|
|
||||||
|
|
||||||
|
def detect_available_backends() -> List[dict]:
|
||||||
|
"""Probe which backends can be imported and return (backend, policy) combos.
|
||||||
|
|
||||||
|
Returns list of dicts with keys: backend, policy, description.
|
||||||
|
"""
|
||||||
|
combos = []
|
||||||
|
|
||||||
|
# faster-whisper
|
||||||
|
try:
|
||||||
|
import faster_whisper # noqa: F401
|
||||||
|
combos.append({"backend": "faster-whisper", "policy": "localagreement", "description": "faster-whisper + LocalAgreement"})
|
||||||
|
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "description": "faster-whisper + SimulStreaming"})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# mlx-whisper (macOS only)
|
||||||
|
try:
|
||||||
|
import mlx_whisper # noqa: F401
|
||||||
|
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "description": "mlx-whisper + LocalAgreement"})
|
||||||
|
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "description": "mlx-whisper + SimulStreaming"})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# openai-whisper
|
||||||
|
try:
|
||||||
|
import whisper # noqa: F401
|
||||||
|
combos.append({"backend": "whisper", "policy": "localagreement", "description": "openai-whisper + LocalAgreement"})
|
||||||
|
combos.append({"backend": "whisper", "policy": "simulstreaming", "description": "openai-whisper + SimulStreaming"})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# voxtral-mlx
|
||||||
|
try:
|
||||||
|
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
|
||||||
|
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "description": "voxtral-mlx (MLX)"})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# voxtral (HuggingFace)
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
|
||||||
|
combos.append({"backend": "voxtral", "policy": "voxtral", "description": "voxtral (HuggingFace)"})
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return combos
|
||||||
|
|
||||||
|
|
||||||
|
def print_cross_backend_comparison(all_results: List[TestResult]):
|
||||||
|
"""Print a comparison table across backends and policies."""
|
||||||
|
print(f"\n{'=' * 110}")
|
||||||
|
print("CROSS-BACKEND BENCHMARK COMPARISON")
|
||||||
|
print(f"{'=' * 110}")
|
||||||
|
print(
|
||||||
|
f"{'Backend':<18} {'Policy':<16} {'File':<30} "
|
||||||
|
f"{'WER':>7} {'RTF':>6} {'MAE(s)':>7} {'MaxΔ(s)':>8}"
|
||||||
|
)
|
||||||
|
print(f"{'-' * 110}")
|
||||||
|
|
||||||
|
for r in all_results:
|
||||||
|
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
|
||||||
|
rtf_str = f"{r.rtf:.2f}x"
|
||||||
|
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
|
||||||
|
max_str = f"{r.timestamp_max_delta:.3f}" if r.timestamp_max_delta is not None else " -"
|
||||||
|
# Truncate filename for readability
|
||||||
|
fname = r.audio_file[:28] + ".." if len(r.audio_file) > 30 else r.audio_file
|
||||||
|
print(
|
||||||
|
f"{r.backend:<18} {r.policy:<16} {fname:<30} "
|
||||||
|
f"{wer_str:>7} {rtf_str:>6} {mae_str:>7} {max_str:>8}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{'-' * 110}")
|
||||||
|
|
||||||
|
# Per-backend averages
|
||||||
|
from collections import defaultdict
|
||||||
|
by_combo = defaultdict(list)
|
||||||
|
for r in all_results:
|
||||||
|
by_combo[(r.backend, r.policy)].append(r)
|
||||||
|
|
||||||
|
print(f"\n{'Backend':<18} {'Policy':<16} {'Avg WER':>8} {'Avg RTF':>8} {'Avg MAE':>8} {'Files':>6}")
|
||||||
|
print(f"{'-' * 80}")
|
||||||
|
for (backend, policy), group in sorted(by_combo.items()):
|
||||||
|
wer_vals = [r.wer for r in group if r.wer is not None]
|
||||||
|
rtf_vals = [r.rtf for r in group]
|
||||||
|
mae_vals = [r.timestamp_mae for r in group if r.timestamp_mae is not None]
|
||||||
|
avg_wer = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
|
||||||
|
avg_rtf = f"{sum(rtf_vals)/len(rtf_vals):.2f}x"
|
||||||
|
avg_mae = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
|
||||||
|
print(
|
||||||
|
f"{backend:<18} {policy:<16} {avg_wer:>8} {avg_rtf:>8} {avg_mae:>8} {len(group):>6}"
|
||||||
|
)
|
||||||
|
print(f"{'=' * 110}")
|
||||||
|
|
||||||
|
|
||||||
|
def _quiet_loggers(verbose: bool):
|
||||||
|
"""Set internal module log levels to reduce noise."""
|
||||||
|
if verbose:
|
||||||
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
else:
|
||||||
|
for mod in (
|
||||||
|
"whisperlivekit.audio_processor", "whisperlivekit.simul_whisper",
|
||||||
|
"whisperlivekit.tokens_alignment", "whisperlivekit.simul_whisper.align_att_base",
|
||||||
|
"whisperlivekit.simul_whisper.simul_whisper",
|
||||||
|
):
|
||||||
|
logging.getLogger(mod).setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_benchmark(
|
||||||
|
audio_files: List[Path], chunk_ms: int, realtime: bool,
|
||||||
|
model_size: str, lan: str, max_duration: float, vac: bool,
|
||||||
|
verbose: bool,
|
||||||
|
) -> List[TestResult]:
|
||||||
|
"""Run benchmark across all available backend+policy combinations."""
|
||||||
|
combos = detect_available_backends()
|
||||||
|
if not combos:
|
||||||
|
logger.error("No backends available. Install at least one ASR backend.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"Detected {len(combos)} backend+policy combinations:")
|
||||||
|
for c in combos:
|
||||||
|
logger.info(f" - {c['description']}")
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
for i, combo in enumerate(combos, 1):
|
||||||
|
backend = combo["backend"]
|
||||||
|
policy = combo["policy"]
|
||||||
|
desc = combo["description"]
|
||||||
|
|
||||||
|
print(f"\n{'*' * 70}")
|
||||||
|
print(f"* BENCHMARK {i}/{len(combos)}: {desc}")
|
||||||
|
print(f"{'*' * 70}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine = create_engine(
|
||||||
|
backend, model_size, lan, vac=vac, policy=policy,
|
||||||
|
)
|
||||||
|
_quiet_loggers(verbose)
|
||||||
|
|
||||||
|
results = await run_all_tests(
|
||||||
|
engine, audio_files, chunk_ms, realtime,
|
||||||
|
backend=backend, policy=policy, lan=lan,
|
||||||
|
max_duration=max_duration,
|
||||||
|
)
|
||||||
|
all_results.extend(results)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to run {desc}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Offline backend test harness (AudioProcessor-level)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", default="faster-whisper",
|
||||||
|
help="Backend: voxtral, voxtral-mlx, auto, faster-whisper, mlx-whisper, whisper.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--policy", default="",
|
||||||
|
help="Override backend policy: localagreement, simulstreaming, voxtral.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio", default=None,
|
||||||
|
help="Path to a single audio file (WAV, MP3, FLAC, etc.).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio-dir", default=None,
|
||||||
|
help="Directory of audio files to test. Defaults to audio_tests/ if neither --audio nor --audio-dir given.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chunk-ms", type=int, default=100,
|
||||||
|
help="Chunk size in milliseconds (simulates real-time interval).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", default="", dest="model_size",
|
||||||
|
help="Model size or HF repo ID.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--lan", default="en", help="Language code.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-realtime", action="store_true",
|
||||||
|
help="Skip real-time pacing between chunks (faster but less realistic).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-vac", action="store_true",
|
||||||
|
help="Disable Voice Activity Classification (send all audio without silence filtering).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--diarization", action="store_true",
|
||||||
|
help="Enable speaker diarization.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--diarization-backend",
|
||||||
|
default="",
|
||||||
|
choices=["diart", "sortformer"],
|
||||||
|
help="Diarization backend when --diarization is enabled.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--benchmark", action="store_true",
|
||||||
|
help="Run benchmark across all detected backend+policy combinations.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--json", default=None, dest="json_output",
|
||||||
|
help="Write structured JSON results to this file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-duration", type=float, default=60.0,
|
||||||
|
help="Skip audio files longer than this many seconds (default: 60).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--insert-silence", nargs=2, type=float, metavar=("SECS", "AT_SEC"),
|
||||||
|
action="append", default=[],
|
||||||
|
help="Insert SECS of silence at AT_SEC position. Can be repeated. "
|
||||||
|
"E.g.: --insert-silence 3.0 2.0 --insert-silence 5.0 7.0",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-v", "--verbose", action="store_true",
|
||||||
|
help="Show debug-level logs from all components.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
realtime = not args.no_realtime
|
||||||
|
vac = not args.no_vac
|
||||||
|
|
||||||
|
# Resolve audio file(s)
|
||||||
|
if args.audio:
|
||||||
|
audio_files = [Path(args.audio)]
|
||||||
|
elif args.audio_dir:
|
||||||
|
audio_files = discover_audio_files(args.audio_dir)
|
||||||
|
elif AUDIO_TESTS_DIR.is_dir():
|
||||||
|
audio_files = discover_audio_files(str(AUDIO_TESTS_DIR))
|
||||||
|
else:
|
||||||
|
# Fall back to jfk.wav download
|
||||||
|
audio_files = [download_sample_audio()]
|
||||||
|
|
||||||
|
if not audio_files:
|
||||||
|
logger.error("No audio files found.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
logger.info(f"Audio files: {[f.name for f in audio_files]}")
|
||||||
|
|
||||||
|
if args.benchmark:
|
||||||
|
# --- Multi-backend benchmark mode ---
|
||||||
|
all_results = asyncio.run(
|
||||||
|
run_benchmark(
|
||||||
|
audio_files, args.chunk_ms, realtime,
|
||||||
|
args.model_size, args.lan, args.max_duration, vac,
|
||||||
|
args.verbose,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if all_results:
|
||||||
|
print_cross_backend_comparison(all_results)
|
||||||
|
results = all_results
|
||||||
|
else:
|
||||||
|
# --- Single-backend mode ---
|
||||||
|
policy = args.policy
|
||||||
|
logger.info(f"Creating {args.backend} engine...")
|
||||||
|
engine = create_engine(
|
||||||
|
args.backend, args.model_size, args.lan,
|
||||||
|
diarization=args.diarization,
|
||||||
|
diarization_backend=args.diarization_backend,
|
||||||
|
vac=vac,
|
||||||
|
policy=policy,
|
||||||
|
)
|
||||||
|
logger.info("Engine ready.")
|
||||||
|
|
||||||
|
_quiet_loggers(args.verbose)
|
||||||
|
|
||||||
|
results = asyncio.run(
|
||||||
|
run_all_tests(
|
||||||
|
engine, audio_files, args.chunk_ms, realtime,
|
||||||
|
args.backend, policy, args.lan,
|
||||||
|
max_duration=args.max_duration,
|
||||||
|
silence_insertions=args.insert_silence or None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(results) > 1:
|
||||||
|
print_benchmark_summary(results)
|
||||||
|
|
||||||
|
# JSON output
|
||||||
|
if args.json_output and results:
|
||||||
|
json_results = []
|
||||||
|
for r in results:
|
||||||
|
d = asdict(r)
|
||||||
|
d.pop("last_response", None) # too verbose for summary
|
||||||
|
json_results.append(d)
|
||||||
|
Path(args.json_output).write_text(
|
||||||
|
json.dumps(json_results, indent=2, ensure_ascii=False)
|
||||||
|
)
|
||||||
|
logger.info(f"Results written to {args.json_output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
58
tests/conftest.py
Normal file
58
tests/conftest.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Shared pytest fixtures for WhisperLiveKit tests."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Silence, Transcript
|
||||||
|
|
||||||
|
|
||||||
|
AUDIO_TESTS_DIR = Path(__file__).parent.parent / "audio_tests"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_tokens():
|
||||||
|
"""A short sequence of ASRToken objects."""
|
||||||
|
return [
|
||||||
|
ASRToken(start=0.0, end=0.5, text="Hello"),
|
||||||
|
ASRToken(start=0.5, end=1.0, text=" world"),
|
||||||
|
ASRToken(start=1.0, end=1.5, text=" test."),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_silence():
|
||||||
|
"""A completed silence event."""
|
||||||
|
s = Silence(start=1.5, end=3.0, is_starting=False, has_ended=True)
|
||||||
|
s.compute_duration()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_args():
|
||||||
|
"""Minimal args namespace for AudioProcessor tests."""
|
||||||
|
return SimpleNamespace(
|
||||||
|
diarization=False,
|
||||||
|
transcription=True,
|
||||||
|
target_language="",
|
||||||
|
vac=False,
|
||||||
|
vac_chunk_size=0.04,
|
||||||
|
min_chunk_size=0.1,
|
||||||
|
pcm_input=True,
|
||||||
|
punctuation_split=False,
|
||||||
|
backend="faster-whisper",
|
||||||
|
backend_policy="localagreement",
|
||||||
|
vad=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ground_truth_en():
|
||||||
|
"""Ground truth transcript for the 7s English audio (if available)."""
|
||||||
|
path = AUDIO_TESTS_DIR / "00_00_07_english_1_speaker.transcript.json"
|
||||||
|
if path.exists():
|
||||||
|
with open(path) as f:
|
||||||
|
return json.load(f)
|
||||||
|
return None
|
||||||
209
tests/test_audio_processor.py
Normal file
209
tests/test_audio_processor.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""Tests for AudioProcessor pipeline with mocked ASR backends.
|
||||||
|
|
||||||
|
These tests verify the async audio processing pipeline works correctly
|
||||||
|
without requiring any real ASR models to be loaded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Mock ASR components
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class MockASR:
|
||||||
|
"""Mock ASR model holder."""
|
||||||
|
sep = " "
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.transcribe_kargs = {}
|
||||||
|
self.original_language = "en"
|
||||||
|
self.backend_choice = "mock"
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MockOnlineProcessor:
|
||||||
|
"""Mock online processor that returns canned tokens."""
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
|
def __init__(self, asr=None):
|
||||||
|
self.asr = asr or MockASR()
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
self.end = 0.0
|
||||||
|
self._call_count = 0
|
||||||
|
self._finished = False
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio, audio_stream_end_time):
|
||||||
|
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||||
|
self.end = audio_stream_end_time
|
||||||
|
|
||||||
|
def process_iter(self, is_last=False):
|
||||||
|
self._call_count += 1
|
||||||
|
# Emit a token on every call when we have audio
|
||||||
|
if len(self.audio_buffer) > 0:
|
||||||
|
t = self._call_count * 0.5
|
||||||
|
return [ASRToken(start=t, end=t + 0.5, text=f"word{self._call_count}")], self.end
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def get_buffer(self):
|
||||||
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
|
def start_silence(self):
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def end_silence(self, silence_duration, offset):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
self._finished = True
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def warmup(self, audio, init_prompt=""):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pcm_bytes(duration_s=0.1, sample_rate=16000):
|
||||||
|
"""Generate silent PCM s16le bytes."""
|
||||||
|
n_samples = int(duration_s * sample_rate)
|
||||||
|
audio = np.zeros(n_samples, dtype=np.float32)
|
||||||
|
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_engine():
|
||||||
|
"""Create a mock TranscriptionEngine-like object."""
|
||||||
|
engine = SimpleNamespace(
|
||||||
|
asr=MockASR(),
|
||||||
|
diarization_model=None,
|
||||||
|
translation_model=None,
|
||||||
|
args=SimpleNamespace(
|
||||||
|
diarization=False,
|
||||||
|
transcription=True,
|
||||||
|
target_language="",
|
||||||
|
vac=False,
|
||||||
|
vac_chunk_size=0.04,
|
||||||
|
min_chunk_size=0.1,
|
||||||
|
pcm_input=True,
|
||||||
|
punctuation_split=False,
|
||||||
|
backend="mock",
|
||||||
|
backend_policy="localagreement",
|
||||||
|
vad=True,
|
||||||
|
model_size="base",
|
||||||
|
lan="en",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPCMConversion:
|
||||||
|
"""Test PCM byte conversion without needing the full pipeline."""
|
||||||
|
|
||||||
|
def test_s16le_roundtrip(self):
|
||||||
|
"""Convert float32 → s16le → float32 and verify approximate roundtrip."""
|
||||||
|
original = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype=np.float32)
|
||||||
|
s16 = (original * 32768).clip(-32768, 32767).astype(np.int16)
|
||||||
|
pcm_bytes = s16.tobytes()
|
||||||
|
# Direct numpy conversion (same logic as AudioProcessor.convert_pcm_to_float)
|
||||||
|
recovered = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
np.testing.assert_allclose(recovered, original, atol=1 / 32768)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestPipelineBasics:
|
||||||
|
async def test_feed_audio_and_get_responses(self, mock_engine):
|
||||||
|
"""Feed audio through the pipeline and verify we get responses."""
|
||||||
|
from whisperlivekit.audio_processor import AudioProcessor
|
||||||
|
|
||||||
|
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
|
||||||
|
processor = AudioProcessor(transcription_engine=mock_engine)
|
||||||
|
results_gen = await processor.create_tasks()
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
async def collect():
|
||||||
|
async for resp in results_gen:
|
||||||
|
responses.append(resp)
|
||||||
|
|
||||||
|
task = asyncio.create_task(collect())
|
||||||
|
|
||||||
|
# Feed 2 seconds of audio in 100ms chunks
|
||||||
|
for _ in range(20):
|
||||||
|
await processor.process_audio(_make_pcm_bytes(0.1))
|
||||||
|
|
||||||
|
# Signal EOF
|
||||||
|
await processor.process_audio(None)
|
||||||
|
|
||||||
|
await asyncio.wait_for(task, timeout=10.0)
|
||||||
|
await processor.cleanup()
|
||||||
|
|
||||||
|
# We should have gotten at least one response
|
||||||
|
assert len(responses) > 0
|
||||||
|
|
||||||
|
async def test_eof_terminates_pipeline(self, mock_engine):
|
||||||
|
"""Sending None (EOF) should cleanly terminate the pipeline."""
|
||||||
|
from whisperlivekit.audio_processor import AudioProcessor
|
||||||
|
|
||||||
|
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
|
||||||
|
processor = AudioProcessor(transcription_engine=mock_engine)
|
||||||
|
results_gen = await processor.create_tasks()
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
async def collect():
|
||||||
|
async for resp in results_gen:
|
||||||
|
responses.append(resp)
|
||||||
|
|
||||||
|
task = asyncio.create_task(collect())
|
||||||
|
|
||||||
|
# Send a small amount of audio then EOF
|
||||||
|
await processor.process_audio(_make_pcm_bytes(0.5))
|
||||||
|
await processor.process_audio(None)
|
||||||
|
|
||||||
|
await asyncio.wait_for(task, timeout=10.0)
|
||||||
|
await processor.cleanup()
|
||||||
|
|
||||||
|
# Pipeline should have terminated without error
|
||||||
|
assert task.done()
|
||||||
|
|
||||||
|
async def test_empty_audio_no_crash(self, mock_engine):
|
||||||
|
"""Sending EOF immediately (no audio) should not crash."""
|
||||||
|
from whisperlivekit.audio_processor import AudioProcessor
|
||||||
|
|
||||||
|
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
|
||||||
|
processor = AudioProcessor(transcription_engine=mock_engine)
|
||||||
|
results_gen = await processor.create_tasks()
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
async def collect():
|
||||||
|
async for resp in results_gen:
|
||||||
|
responses.append(resp)
|
||||||
|
|
||||||
|
task = asyncio.create_task(collect())
|
||||||
|
await processor.process_audio(None)
|
||||||
|
|
||||||
|
await asyncio.wait_for(task, timeout=10.0)
|
||||||
|
await processor.cleanup()
|
||||||
|
assert task.done()
|
||||||
99
tests/test_config.py
Normal file
99
tests/test_config.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Tests for WhisperLiveKitConfig."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from whisperlivekit.config import WhisperLiveKitConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaults:
|
||||||
|
def test_default_backend(self):
|
||||||
|
c = WhisperLiveKitConfig()
|
||||||
|
assert c.backend == "auto"
|
||||||
|
|
||||||
|
def test_default_policy(self):
|
||||||
|
c = WhisperLiveKitConfig()
|
||||||
|
assert c.backend_policy == "simulstreaming"
|
||||||
|
|
||||||
|
def test_default_language(self):
|
||||||
|
c = WhisperLiveKitConfig()
|
||||||
|
assert c.lan == "auto"
|
||||||
|
|
||||||
|
def test_default_vac(self):
|
||||||
|
c = WhisperLiveKitConfig()
|
||||||
|
assert c.vac is True
|
||||||
|
|
||||||
|
def test_default_model_size(self):
|
||||||
|
c = WhisperLiveKitConfig()
|
||||||
|
assert c.model_size == "base"
|
||||||
|
|
||||||
|
def test_default_transcription(self):
|
||||||
|
c = WhisperLiveKitConfig()
|
||||||
|
assert c.transcription is True
|
||||||
|
assert c.diarization is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestPostInit:
|
||||||
|
def test_en_model_forces_english(self):
|
||||||
|
c = WhisperLiveKitConfig(model_size="tiny.en")
|
||||||
|
assert c.lan == "en"
|
||||||
|
|
||||||
|
def test_en_suffix_with_auto_language(self):
|
||||||
|
c = WhisperLiveKitConfig(model_size="base.en", lan="auto")
|
||||||
|
assert c.lan == "en"
|
||||||
|
|
||||||
|
def test_non_en_model_keeps_language(self):
|
||||||
|
c = WhisperLiveKitConfig(model_size="base", lan="fr")
|
||||||
|
assert c.lan == "fr"
|
||||||
|
|
||||||
|
def test_policy_alias_1(self):
|
||||||
|
c = WhisperLiveKitConfig(backend_policy="1")
|
||||||
|
assert c.backend_policy == "simulstreaming"
|
||||||
|
|
||||||
|
def test_policy_alias_2(self):
|
||||||
|
c = WhisperLiveKitConfig(backend_policy="2")
|
||||||
|
assert c.backend_policy == "localagreement"
|
||||||
|
|
||||||
|
def test_policy_no_alias(self):
|
||||||
|
c = WhisperLiveKitConfig(backend_policy="localagreement")
|
||||||
|
assert c.backend_policy == "localagreement"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFromNamespace:
|
||||||
|
def test_known_keys(self):
|
||||||
|
ns = SimpleNamespace(backend="faster-whisper", lan="en", model_size="large-v3")
|
||||||
|
c = WhisperLiveKitConfig.from_namespace(ns)
|
||||||
|
assert c.backend == "faster-whisper"
|
||||||
|
assert c.lan == "en"
|
||||||
|
assert c.model_size == "large-v3"
|
||||||
|
|
||||||
|
def test_ignores_unknown_keys(self):
|
||||||
|
ns = SimpleNamespace(backend="auto", unknown_key="value", another="x")
|
||||||
|
c = WhisperLiveKitConfig.from_namespace(ns)
|
||||||
|
assert c.backend == "auto"
|
||||||
|
assert not hasattr(c, "unknown_key")
|
||||||
|
|
||||||
|
def test_preserves_defaults_for_missing(self):
|
||||||
|
ns = SimpleNamespace(backend="voxtral-mlx")
|
||||||
|
c = WhisperLiveKitConfig.from_namespace(ns)
|
||||||
|
assert c.lan == "auto"
|
||||||
|
assert c.vac is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestFromKwargs:
|
||||||
|
def test_known_keys(self):
|
||||||
|
c = WhisperLiveKitConfig.from_kwargs(backend="mlx-whisper", lan="fr")
|
||||||
|
assert c.backend == "mlx-whisper"
|
||||||
|
assert c.lan == "fr"
|
||||||
|
|
||||||
|
def test_warns_on_unknown_keys(self, caplog):
|
||||||
|
with caplog.at_level(logging.WARNING, logger="whisperlivekit.config"):
|
||||||
|
c = WhisperLiveKitConfig.from_kwargs(backend="auto", bogus="value")
|
||||||
|
assert c.backend == "auto"
|
||||||
|
assert "bogus" in caplog.text
|
||||||
|
|
||||||
|
def test_post_init_runs(self):
|
||||||
|
c = WhisperLiveKitConfig.from_kwargs(model_size="small.en")
|
||||||
|
assert c.lan == "en"
|
||||||
172
tests/test_hypothesis_buffer.py
Normal file
172
tests/test_hypothesis_buffer.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""Tests for HypothesisBuffer — the core of LocalAgreement policy."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from whisperlivekit.local_agreement.online_asr import HypothesisBuffer
|
||||||
|
|
||||||
|
|
||||||
|
def make_tokens(words, start=0.0, step=0.5):
|
||||||
|
"""Helper: create ASRToken list from word strings."""
|
||||||
|
tokens = []
|
||||||
|
t = start
|
||||||
|
for w in words:
|
||||||
|
tokens.append(ASRToken(start=t, end=t + step, text=w, probability=0.9))
|
||||||
|
t += step
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TestInsert:
|
||||||
|
def test_basic_insert(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
tokens = make_tokens(["hello", "world"])
|
||||||
|
buf.insert(tokens, offset=0.0)
|
||||||
|
assert len(buf.new) == 2
|
||||||
|
assert buf.new[0].text == "hello"
|
||||||
|
|
||||||
|
def test_insert_with_offset(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
tokens = make_tokens(["hello"], start=0.0)
|
||||||
|
buf.insert(tokens, offset=5.0)
|
||||||
|
assert buf.new[0].start == pytest.approx(5.0)
|
||||||
|
|
||||||
|
def test_insert_filters_old_tokens(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
buf.last_committed_time = 10.0
|
||||||
|
tokens = make_tokens(["old", "new"], start=5.0, step=3.0)
|
||||||
|
buf.insert(tokens, offset=0.0)
|
||||||
|
# "old" at 5.0 is before last_committed_time - 0.1 = 9.9 → filtered
|
||||||
|
# "new" at 8.0 is also before 9.9 → filtered
|
||||||
|
assert len(buf.new) == 0
|
||||||
|
|
||||||
|
def test_insert_deduplicates_committed(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
# Commit "hello"
|
||||||
|
tokens1 = make_tokens(["hello", "world"])
|
||||||
|
buf.insert(tokens1, offset=0.0)
|
||||||
|
buf.flush() # commits "hello" (buffer was empty, so nothing matches)
|
||||||
|
# Actually with empty buffer, flush won't commit anything
|
||||||
|
# Let's do it properly: two rounds
|
||||||
|
buf2 = HypothesisBuffer()
|
||||||
|
first = make_tokens(["hello", "world"])
|
||||||
|
buf2.insert(first, offset=0.0)
|
||||||
|
buf2.flush() # buffer was empty → no commits, buffer = ["hello", "world"]
|
||||||
|
|
||||||
|
second = make_tokens(["hello", "world", "test"])
|
||||||
|
buf2.insert(second, offset=0.0)
|
||||||
|
committed = buf2.flush()
|
||||||
|
# LCP of ["hello", "world"] and ["hello", "world", "test"] = ["hello", "world"]
|
||||||
|
assert len(committed) == 2
|
||||||
|
assert committed[0].text == "hello"
|
||||||
|
assert committed[1].text == "world"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlush:
|
||||||
|
def test_flush_empty(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
committed = buf.flush()
|
||||||
|
assert committed == []
|
||||||
|
|
||||||
|
def test_flush_lcp_matching(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
# Round 1: establish buffer
|
||||||
|
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
|
||||||
|
buf.flush() # buffer = ["hello", "world"], committed = []
|
||||||
|
|
||||||
|
# Round 2: same prefix, new suffix
|
||||||
|
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
|
||||||
|
committed = buf.flush()
|
||||||
|
assert [t.text for t in committed] == ["hello", "world"]
|
||||||
|
|
||||||
|
def test_flush_no_match(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
# Round 1
|
||||||
|
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
|
||||||
|
buf.flush()
|
||||||
|
|
||||||
|
# Round 2: completely different
|
||||||
|
buf.insert(make_tokens(["foo", "bar"]), offset=0.0)
|
||||||
|
committed = buf.flush()
|
||||||
|
assert committed == []
|
||||||
|
|
||||||
|
def test_flush_partial_match(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
|
||||||
|
buf.flush()
|
||||||
|
|
||||||
|
buf.insert(make_tokens(["hello", "earth", "again"]), offset=0.0)
|
||||||
|
committed = buf.flush()
|
||||||
|
assert len(committed) == 1
|
||||||
|
assert committed[0].text == "hello"
|
||||||
|
|
||||||
|
def test_flush_updates_last_committed(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
|
||||||
|
buf.flush()
|
||||||
|
|
||||||
|
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
|
||||||
|
buf.flush()
|
||||||
|
assert buf.last_committed_word == "world"
|
||||||
|
assert buf.last_committed_time > 0
|
||||||
|
|
||||||
|
def test_flush_with_confidence_validation(self):
|
||||||
|
buf = HypothesisBuffer(confidence_validation=True)
|
||||||
|
high_conf = [
|
||||||
|
ASRToken(start=0.0, end=0.5, text="sure", probability=0.99),
|
||||||
|
ASRToken(start=0.5, end=1.0, text="maybe", probability=0.5),
|
||||||
|
]
|
||||||
|
buf.insert(high_conf, offset=0.0)
|
||||||
|
committed = buf.flush()
|
||||||
|
# "sure" has p>0.95 → committed immediately
|
||||||
|
assert len(committed) == 1
|
||||||
|
assert committed[0].text == "sure"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPopCommitted:
|
||||||
|
def test_pop_removes_old(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
buf.committed_in_buffer = make_tokens(["a", "b", "c"], start=0.0, step=1.0)
|
||||||
|
# "a": end=1.0, "b": end=2.0, "c": end=3.0
|
||||||
|
# pop_committed removes tokens with end <= time
|
||||||
|
buf.pop_committed(2.0)
|
||||||
|
# "a" (end=1.0) and "b" (end=2.0) removed, "c" (end=3.0) remains
|
||||||
|
assert len(buf.committed_in_buffer) == 1
|
||||||
|
assert buf.committed_in_buffer[0].text == "c"
|
||||||
|
|
||||||
|
def test_pop_nothing(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
buf.committed_in_buffer = make_tokens(["a", "b"], start=5.0)
|
||||||
|
buf.pop_committed(0.0)
|
||||||
|
assert len(buf.committed_in_buffer) == 2
|
||||||
|
|
||||||
|
def test_pop_all(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
buf.committed_in_buffer = make_tokens(["a", "b"], start=0.0, step=0.5)
|
||||||
|
buf.pop_committed(100.0)
|
||||||
|
assert len(buf.committed_in_buffer) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingSimulation:
|
||||||
|
"""Multi-round insert/flush simulating real streaming behavior."""
|
||||||
|
|
||||||
|
def test_three_rounds(self):
|
||||||
|
buf = HypothesisBuffer()
|
||||||
|
all_committed = []
|
||||||
|
|
||||||
|
# Round 1: "this is"
|
||||||
|
buf.insert(make_tokens(["this", "is"]), offset=0.0)
|
||||||
|
all_committed.extend(buf.flush())
|
||||||
|
|
||||||
|
# Round 2: "this is a test"
|
||||||
|
buf.insert(make_tokens(["this", "is", "a", "test"]), offset=0.0)
|
||||||
|
all_committed.extend(buf.flush())
|
||||||
|
|
||||||
|
# Round 3: "this is a test today"
|
||||||
|
buf.insert(make_tokens(["this", "is", "a", "test", "today"]), offset=0.0)
|
||||||
|
all_committed.extend(buf.flush())
|
||||||
|
|
||||||
|
words = [t.text for t in all_committed]
|
||||||
|
assert "this" in words
|
||||||
|
assert "is" in words
|
||||||
|
assert "a" in words
|
||||||
|
assert "test" in words
|
||||||
183
tests/test_metrics.py
Normal file
183
tests/test_metrics.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Tests for whisperlivekit.metrics — WER, timestamp accuracy, normalization."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy, normalize_text
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeText:
|
||||||
|
def test_lowercase(self):
|
||||||
|
assert normalize_text("Hello World") == "hello world"
|
||||||
|
|
||||||
|
def test_strip_punctuation(self):
|
||||||
|
assert normalize_text("Hello, world!") == "hello world"
|
||||||
|
|
||||||
|
def test_collapse_whitespace(self):
|
||||||
|
assert normalize_text(" hello world ") == "hello world"
|
||||||
|
|
||||||
|
def test_keep_hyphens(self):
|
||||||
|
assert normalize_text("real-time") == "real-time"
|
||||||
|
|
||||||
|
def test_keep_apostrophes(self):
|
||||||
|
assert normalize_text("don't") == "don't"
|
||||||
|
|
||||||
|
def test_unicode_normalized(self):
|
||||||
|
# e + combining accent should be same as precomposed
|
||||||
|
assert normalize_text("caf\u0065\u0301") == normalize_text("caf\u00e9")
|
||||||
|
|
||||||
|
def test_empty(self):
|
||||||
|
assert normalize_text("") == ""
|
||||||
|
|
||||||
|
def test_only_punctuation(self):
|
||||||
|
assert normalize_text("...!?") == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeWER:
|
||||||
|
def test_perfect_match(self):
|
||||||
|
result = compute_wer("hello world", "hello world")
|
||||||
|
assert result["wer"] == 0.0
|
||||||
|
assert result["substitutions"] == 0
|
||||||
|
assert result["insertions"] == 0
|
||||||
|
assert result["deletions"] == 0
|
||||||
|
|
||||||
|
def test_case_insensitive(self):
|
||||||
|
result = compute_wer("Hello World", "hello world")
|
||||||
|
assert result["wer"] == 0.0
|
||||||
|
|
||||||
|
def test_punctuation_ignored(self):
|
||||||
|
result = compute_wer("Hello, world!", "hello world")
|
||||||
|
assert result["wer"] == 0.0
|
||||||
|
|
||||||
|
def test_one_substitution(self):
|
||||||
|
result = compute_wer("hello world", "hello earth")
|
||||||
|
assert result["wer"] == pytest.approx(0.5)
|
||||||
|
assert result["substitutions"] == 1
|
||||||
|
|
||||||
|
def test_one_insertion(self):
|
||||||
|
result = compute_wer("hello world", "hello big world")
|
||||||
|
assert result["wer"] == pytest.approx(0.5)
|
||||||
|
assert result["insertions"] == 1
|
||||||
|
|
||||||
|
def test_one_deletion(self):
|
||||||
|
result = compute_wer("hello big world", "hello world")
|
||||||
|
assert result["wer"] == pytest.approx(1 / 3)
|
||||||
|
assert result["deletions"] == 1
|
||||||
|
|
||||||
|
def test_completely_different(self):
|
||||||
|
result = compute_wer("the cat sat", "a dog ran")
|
||||||
|
assert result["wer"] == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_empty_reference(self):
|
||||||
|
result = compute_wer("", "hello")
|
||||||
|
assert result["wer"] == 1.0 # 1 insertion / 0 ref → treated as float(m)
|
||||||
|
assert result["ref_words"] == 0
|
||||||
|
|
||||||
|
def test_empty_hypothesis(self):
|
||||||
|
result = compute_wer("hello world", "")
|
||||||
|
assert result["wer"] == pytest.approx(1.0)
|
||||||
|
assert result["deletions"] == 2
|
||||||
|
|
||||||
|
def test_both_empty(self):
|
||||||
|
result = compute_wer("", "")
|
||||||
|
assert result["wer"] == 0.0
|
||||||
|
|
||||||
|
def test_ref_and_hyp_word_counts(self):
|
||||||
|
result = compute_wer("one two three", "one two three four")
|
||||||
|
assert result["ref_words"] == 3
|
||||||
|
assert result["hyp_words"] == 4
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeTimestampAccuracy:
|
||||||
|
def test_perfect_match(self):
|
||||||
|
words = [
|
||||||
|
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": "world", "start": 0.5, "end": 1.0},
|
||||||
|
]
|
||||||
|
result = compute_timestamp_accuracy(words, words)
|
||||||
|
assert result["mae_start"] == 0.0
|
||||||
|
assert result["max_delta_start"] == 0.0
|
||||||
|
assert result["n_matched"] == 2
|
||||||
|
|
||||||
|
def test_constant_offset(self):
|
||||||
|
ref = [
|
||||||
|
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": "world", "start": 0.5, "end": 1.0},
|
||||||
|
]
|
||||||
|
pred = [
|
||||||
|
{"word": "hello", "start": 0.1, "end": 0.6},
|
||||||
|
{"word": "world", "start": 0.6, "end": 1.1},
|
||||||
|
]
|
||||||
|
result = compute_timestamp_accuracy(pred, ref)
|
||||||
|
assert result["mae_start"] == pytest.approx(0.1)
|
||||||
|
assert result["max_delta_start"] == pytest.approx(0.1)
|
||||||
|
assert result["n_matched"] == 2
|
||||||
|
|
||||||
|
def test_mismatched_word_counts(self):
|
||||||
|
ref = [
|
||||||
|
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": "beautiful", "start": 0.5, "end": 1.0},
|
||||||
|
{"word": "world", "start": 1.0, "end": 1.5},
|
||||||
|
]
|
||||||
|
pred = [
|
||||||
|
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": "world", "start": 1.1, "end": 1.6},
|
||||||
|
]
|
||||||
|
result = compute_timestamp_accuracy(pred, ref)
|
||||||
|
assert result["n_matched"] == 2
|
||||||
|
assert result["n_ref"] == 3
|
||||||
|
assert result["n_pred"] == 2
|
||||||
|
|
||||||
|
def test_empty_predicted(self):
|
||||||
|
ref = [{"word": "hello", "start": 0.0, "end": 0.5}]
|
||||||
|
result = compute_timestamp_accuracy([], ref)
|
||||||
|
assert result["mae_start"] is None
|
||||||
|
assert result["n_matched"] == 0
|
||||||
|
|
||||||
|
def test_empty_reference(self):
|
||||||
|
pred = [{"word": "hello", "start": 0.0, "end": 0.5}]
|
||||||
|
result = compute_timestamp_accuracy(pred, [])
|
||||||
|
assert result["mae_start"] is None
|
||||||
|
assert result["n_matched"] == 0
|
||||||
|
|
||||||
|
def test_case_insensitive_matching(self):
|
||||||
|
ref = [{"word": "Hello", "start": 0.0, "end": 0.5}]
|
||||||
|
pred = [{"word": "hello", "start": 0.1, "end": 0.6}]
|
||||||
|
result = compute_timestamp_accuracy(pred, ref)
|
||||||
|
assert result["n_matched"] == 1
|
||||||
|
assert result["mae_start"] == pytest.approx(0.1)
|
||||||
|
|
||||||
|
def test_median_even_count(self):
|
||||||
|
"""Median with even number of matched words should average the two middle values."""
|
||||||
|
ref = [
|
||||||
|
{"word": "a", "start": 0.0, "end": 0.2},
|
||||||
|
{"word": "b", "start": 0.5, "end": 0.7},
|
||||||
|
{"word": "c", "start": 1.0, "end": 1.2},
|
||||||
|
{"word": "d", "start": 1.5, "end": 1.7},
|
||||||
|
]
|
||||||
|
pred = [
|
||||||
|
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
|
||||||
|
{"word": "b", "start": 0.7, "end": 0.9}, # delta 0.2
|
||||||
|
{"word": "c", "start": 1.3, "end": 1.5}, # delta 0.3
|
||||||
|
{"word": "d", "start": 1.9, "end": 2.1}, # delta 0.4
|
||||||
|
]
|
||||||
|
result = compute_timestamp_accuracy(pred, ref)
|
||||||
|
assert result["n_matched"] == 4
|
||||||
|
# sorted abs deltas: [0.1, 0.2, 0.3, 0.4] -> median = (0.2 + 0.3) / 2 = 0.25
|
||||||
|
assert result["median_delta_start"] == pytest.approx(0.25)
|
||||||
|
|
||||||
|
def test_median_odd_count(self):
|
||||||
|
"""Median with odd number of matched words takes the middle value."""
|
||||||
|
ref = [
|
||||||
|
{"word": "a", "start": 0.0, "end": 0.2},
|
||||||
|
{"word": "b", "start": 0.5, "end": 0.7},
|
||||||
|
{"word": "c", "start": 1.0, "end": 1.2},
|
||||||
|
]
|
||||||
|
pred = [
|
||||||
|
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
|
||||||
|
{"word": "b", "start": 0.8, "end": 1.0}, # delta 0.3
|
||||||
|
{"word": "c", "start": 1.2, "end": 1.4}, # delta 0.2
|
||||||
|
]
|
||||||
|
result = compute_timestamp_accuracy(pred, ref)
|
||||||
|
assert result["n_matched"] == 3
|
||||||
|
# sorted abs deltas: [0.1, 0.2, 0.3] -> median = 0.2
|
||||||
|
assert result["median_delta_start"] == pytest.approx(0.2)
|
||||||
99
tests/test_silence_handling.py
Normal file
99
tests/test_silence_handling.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Tests for silence handling — state machine and double-counting regression."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import Silence
|
||||||
|
|
||||||
|
|
||||||
|
class TestSilenceStateMachine:
|
||||||
|
"""Test Silence object state transitions."""
|
||||||
|
|
||||||
|
def test_initial_state(self):
|
||||||
|
s = Silence(start=1.0, is_starting=True)
|
||||||
|
assert s.is_starting is True
|
||||||
|
assert s.has_ended is False
|
||||||
|
assert s.duration is None
|
||||||
|
assert s.end is None
|
||||||
|
|
||||||
|
def test_end_silence(self):
|
||||||
|
s = Silence(start=1.0, is_starting=True)
|
||||||
|
s.end = 3.0
|
||||||
|
s.is_starting = False
|
||||||
|
s.has_ended = True
|
||||||
|
s.compute_duration()
|
||||||
|
assert s.duration == pytest.approx(2.0)
|
||||||
|
|
||||||
|
def test_very_short_silence(self):
|
||||||
|
s = Silence(start=1.0, end=1.01, is_starting=False, has_ended=True)
|
||||||
|
s.compute_duration()
|
||||||
|
assert s.duration == pytest.approx(0.01)
|
||||||
|
|
||||||
|
def test_zero_duration_silence(self):
|
||||||
|
s = Silence(start=5.0, end=5.0)
|
||||||
|
s.compute_duration()
|
||||||
|
assert s.duration == pytest.approx(0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSilenceDoubleCounting:
|
||||||
|
"""Regression tests for the silence double-counting bug.
|
||||||
|
|
||||||
|
The bug: _begin_silence and _end_silence both pushed self.current_silence
|
||||||
|
to the queue. Since they were the same Python object, _end_silence's mutation
|
||||||
|
affected the already-queued start event. The consumer processed both as
|
||||||
|
ended silences, doubling the duration.
|
||||||
|
|
||||||
|
Fix: _begin_silence now pushes a separate Silence object for the start event.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_start_and_end_are_separate_objects(self):
|
||||||
|
"""Simulate the fix: start event and end event must be different objects."""
|
||||||
|
# Simulate _begin_silence: creates start event as separate object
|
||||||
|
current_silence = Silence(start=1.0, is_starting=True)
|
||||||
|
start_event = Silence(start=1.0, is_starting=True) # separate copy
|
||||||
|
|
||||||
|
# Simulate _end_silence: mutates current_silence
|
||||||
|
current_silence.end = 3.0
|
||||||
|
current_silence.is_starting = False
|
||||||
|
current_silence.has_ended = True
|
||||||
|
current_silence.compute_duration()
|
||||||
|
|
||||||
|
# start_event should NOT be affected by mutations to current_silence
|
||||||
|
assert start_event.is_starting is True
|
||||||
|
assert start_event.has_ended is False
|
||||||
|
assert start_event.end is None
|
||||||
|
|
||||||
|
# current_silence (end event) has the final state
|
||||||
|
assert current_silence.has_ended is True
|
||||||
|
assert current_silence.duration == pytest.approx(2.0)
|
||||||
|
|
||||||
|
def test_single_object_would_cause_double_counting(self):
|
||||||
|
"""Demonstrate the bug: if same object is used for both events."""
|
||||||
|
shared = Silence(start=1.0, is_starting=True)
|
||||||
|
queue = [shared] # start event queued
|
||||||
|
|
||||||
|
# Mutate (simulates _end_silence)
|
||||||
|
shared.end = 3.0
|
||||||
|
shared.is_starting = False
|
||||||
|
shared.has_ended = True
|
||||||
|
shared.compute_duration()
|
||||||
|
queue.append(shared) # end event queued
|
||||||
|
|
||||||
|
# Both queue items point to the SAME mutated object
|
||||||
|
assert queue[0] is queue[1] # same reference
|
||||||
|
assert queue[0].has_ended is True # start event also shows ended!
|
||||||
|
|
||||||
|
# This would cause double-counting: both items have has_ended=True
|
||||||
|
# and duration=2.0, so the consumer adds 2.0 twice = 4.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsecutiveSilences:
|
||||||
|
def test_multiple_silences(self):
|
||||||
|
"""Multiple silence periods should have independent durations."""
|
||||||
|
s1 = Silence(start=1.0, end=2.0)
|
||||||
|
s1.compute_duration()
|
||||||
|
s2 = Silence(start=5.0, end=8.0)
|
||||||
|
s2.compute_duration()
|
||||||
|
assert s1.duration == pytest.approx(1.0)
|
||||||
|
assert s2.duration == pytest.approx(3.0)
|
||||||
|
# Total silence should be sum, not accumulated on single object
|
||||||
|
assert s1.duration + s2.duration == pytest.approx(4.0)
|
||||||
185
tests/test_timed_objects.py
Normal file
185
tests/test_timed_objects.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""Tests for whisperlivekit.timed_objects data classes."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import (
|
||||||
|
ASRToken,
|
||||||
|
FrontData,
|
||||||
|
Segment,
|
||||||
|
Silence,
|
||||||
|
TimedText,
|
||||||
|
Transcript,
|
||||||
|
format_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatTime:
|
||||||
|
def test_zero(self):
|
||||||
|
assert format_time(0) == "0:00:00"
|
||||||
|
|
||||||
|
def test_one_minute(self):
|
||||||
|
assert format_time(60) == "0:01:00"
|
||||||
|
|
||||||
|
def test_one_hour(self):
|
||||||
|
assert format_time(3600) == "1:00:00"
|
||||||
|
|
||||||
|
def test_fractional_truncated(self):
|
||||||
|
assert format_time(61.9) == "0:01:01"
|
||||||
|
|
||||||
|
|
||||||
|
class TestASRToken:
|
||||||
|
def test_with_offset(self):
|
||||||
|
t = ASRToken(start=1.0, end=2.0, text="hello")
|
||||||
|
shifted = t.with_offset(0.5)
|
||||||
|
assert shifted.start == pytest.approx(1.5)
|
||||||
|
assert shifted.end == pytest.approx(2.5)
|
||||||
|
assert shifted.text == "hello"
|
||||||
|
|
||||||
|
def test_with_offset_preserves_fields(self):
|
||||||
|
t = ASRToken(start=0.0, end=1.0, text="hi", speaker=2, probability=0.95)
|
||||||
|
shifted = t.with_offset(1.0)
|
||||||
|
assert shifted.speaker == 2
|
||||||
|
assert shifted.probability == 0.95
|
||||||
|
|
||||||
|
def test_is_silence_false(self):
|
||||||
|
t = ASRToken(start=0.0, end=1.0, text="hello")
|
||||||
|
assert t.is_silence() is False
|
||||||
|
|
||||||
|
def test_bool_truthy(self):
|
||||||
|
t = ASRToken(start=0.0, end=1.0, text="hello")
|
||||||
|
assert bool(t) is True
|
||||||
|
|
||||||
|
def test_bool_falsy(self):
|
||||||
|
t = ASRToken(start=0.0, end=1.0, text="")
|
||||||
|
assert bool(t) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestTimedText:
|
||||||
|
def test_has_punctuation_period(self):
|
||||||
|
t = TimedText(text="hello.")
|
||||||
|
assert t.has_punctuation() is True
|
||||||
|
|
||||||
|
def test_has_punctuation_exclamation(self):
|
||||||
|
t = TimedText(text="wow!")
|
||||||
|
assert t.has_punctuation() is True
|
||||||
|
|
||||||
|
def test_has_punctuation_question(self):
|
||||||
|
t = TimedText(text="really?")
|
||||||
|
assert t.has_punctuation() is True
|
||||||
|
|
||||||
|
def test_has_punctuation_cjk(self):
|
||||||
|
t = TimedText(text="hello。")
|
||||||
|
assert t.has_punctuation() is True
|
||||||
|
|
||||||
|
def test_no_punctuation(self):
|
||||||
|
t = TimedText(text="hello world")
|
||||||
|
assert t.has_punctuation() is False
|
||||||
|
|
||||||
|
def test_duration(self):
|
||||||
|
t = TimedText(start=1.0, end=3.5)
|
||||||
|
assert t.duration() == pytest.approx(2.5)
|
||||||
|
|
||||||
|
def test_contains_timespan(self):
|
||||||
|
outer = TimedText(start=0.0, end=5.0)
|
||||||
|
inner = TimedText(start=1.0, end=3.0)
|
||||||
|
assert outer.contains_timespan(inner) is True
|
||||||
|
assert inner.contains_timespan(outer) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestSilence:
|
||||||
|
def test_compute_duration(self):
|
||||||
|
s = Silence(start=1.0, end=3.5)
|
||||||
|
d = s.compute_duration()
|
||||||
|
assert d == pytest.approx(2.5)
|
||||||
|
assert s.duration == pytest.approx(2.5)
|
||||||
|
|
||||||
|
def test_compute_duration_none_start(self):
|
||||||
|
s = Silence(start=None, end=3.5)
|
||||||
|
d = s.compute_duration()
|
||||||
|
assert d is None
|
||||||
|
|
||||||
|
def test_compute_duration_none_end(self):
|
||||||
|
s = Silence(start=1.0, end=None)
|
||||||
|
d = s.compute_duration()
|
||||||
|
assert d is None
|
||||||
|
|
||||||
|
def test_is_silence_true(self):
|
||||||
|
s = Silence()
|
||||||
|
assert s.is_silence() is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscript:
|
||||||
|
def test_from_tokens(self, sample_tokens):
|
||||||
|
t = Transcript.from_tokens(sample_tokens, sep="")
|
||||||
|
assert t.text == "Hello world test."
|
||||||
|
assert t.start == pytest.approx(0.0)
|
||||||
|
assert t.end == pytest.approx(1.5)
|
||||||
|
|
||||||
|
def test_from_tokens_with_sep(self, sample_tokens):
|
||||||
|
t = Transcript.from_tokens(sample_tokens, sep="|")
|
||||||
|
assert t.text == "Hello| world| test."
|
||||||
|
|
||||||
|
def test_from_empty_tokens(self):
|
||||||
|
t = Transcript.from_tokens([])
|
||||||
|
assert t.text == ""
|
||||||
|
assert t.start is None
|
||||||
|
assert t.end is None
|
||||||
|
|
||||||
|
def test_from_tokens_with_offset(self, sample_tokens):
|
||||||
|
t = Transcript.from_tokens(sample_tokens, offset=10.0)
|
||||||
|
assert t.start == pytest.approx(10.0)
|
||||||
|
assert t.end == pytest.approx(11.5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSegment:
|
||||||
|
def test_from_tokens(self, sample_tokens):
|
||||||
|
seg = Segment.from_tokens(sample_tokens)
|
||||||
|
assert seg is not None
|
||||||
|
assert seg.text == "Hello world test."
|
||||||
|
assert seg.start == pytest.approx(0.0)
|
||||||
|
assert seg.end == pytest.approx(1.5)
|
||||||
|
assert seg.speaker == -1
|
||||||
|
|
||||||
|
def test_from_silence_tokens(self):
|
||||||
|
silences = [
|
||||||
|
Silence(start=1.0, end=2.0),
|
||||||
|
Silence(start=2.0, end=3.0),
|
||||||
|
]
|
||||||
|
seg = Segment.from_tokens(silences, is_silence=True)
|
||||||
|
assert seg is not None
|
||||||
|
assert seg.speaker == -2
|
||||||
|
assert seg.is_silence() is True
|
||||||
|
assert seg.text is None
|
||||||
|
|
||||||
|
def test_from_empty_tokens(self):
|
||||||
|
seg = Segment.from_tokens([])
|
||||||
|
assert seg is None
|
||||||
|
|
||||||
|
def test_to_dict(self, sample_tokens):
|
||||||
|
seg = Segment.from_tokens(sample_tokens)
|
||||||
|
d = seg.to_dict()
|
||||||
|
assert "text" in d
|
||||||
|
assert "speaker" in d
|
||||||
|
assert "start" in d
|
||||||
|
assert "end" in d
|
||||||
|
|
||||||
|
|
||||||
|
class TestFrontData:
|
||||||
|
def test_to_dict_empty(self):
|
||||||
|
fd = FrontData()
|
||||||
|
d = fd.to_dict()
|
||||||
|
assert d["lines"] == []
|
||||||
|
assert d["buffer_transcription"] == ""
|
||||||
|
assert "error" not in d
|
||||||
|
|
||||||
|
def test_to_dict_with_error(self):
|
||||||
|
fd = FrontData(error="something broke")
|
||||||
|
d = fd.to_dict()
|
||||||
|
assert d["error"] == "something broke"
|
||||||
|
|
||||||
|
def test_to_dict_with_lines(self, sample_tokens):
|
||||||
|
seg = Segment.from_tokens(sample_tokens)
|
||||||
|
fd = FrontData(lines=[seg])
|
||||||
|
d = fd.to_dict()
|
||||||
|
assert len(d["lines"]) == 1
|
||||||
|
assert d["lines"][0]["text"] == "Hello world test."
|
||||||
6575
uv.lock
generated
Normal file
6575
uv.lock
generated
Normal file
File diff suppressed because one or more lines are too long
@@ -9,10 +9,11 @@ import numpy as np
|
|||||||
from whisperlivekit.core import (TranscriptionEngine,
|
from whisperlivekit.core import (TranscriptionEngine,
|
||||||
online_diarization_factory, online_factory,
|
online_diarization_factory, online_factory,
|
||||||
online_translation_factory)
|
online_translation_factory)
|
||||||
|
from whisperlivekit.metrics_collector import SessionMetrics
|
||||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
|
||||||
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
||||||
Line, Silence, State, Transcript)
|
Segment, Silence, State, Transcript)
|
||||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
@@ -66,7 +67,8 @@ class AudioProcessor:
|
|||||||
self.args = models.args
|
self.args = models.args
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
self.channels = 1
|
self.channels = 1
|
||||||
self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
|
chunk_seconds = self.args.vac_chunk_size if self.args.vac else self.args.min_chunk_size
|
||||||
|
self.samples_per_sec = int(self.sample_rate * chunk_seconds)
|
||||||
self.bytes_per_sample = 2
|
self.bytes_per_sample = 2
|
||||||
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||||
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||||
@@ -85,12 +87,14 @@ class AudioProcessor:
|
|||||||
|
|
||||||
# Models and processing
|
# Models and processing
|
||||||
self.asr: Any = models.asr
|
self.asr: Any = models.asr
|
||||||
self.vac_model: Any = models.vac_model
|
|
||||||
if self.args.vac:
|
|
||||||
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
|
|
||||||
else:
|
|
||||||
self.vac: Optional[FixedVADIterator] = None
|
self.vac: Optional[FixedVADIterator] = None
|
||||||
|
|
||||||
|
if self.args.vac:
|
||||||
|
if models.vac_session is not None:
|
||||||
|
vac_model = OnnxWrapper(session=models.vac_session)
|
||||||
|
self.vac = FixedVADIterator(vac_model)
|
||||||
|
else:
|
||||||
|
self.vac = FixedVADIterator(load_jit_vad())
|
||||||
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
||||||
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
||||||
self._ffmpeg_error: Optional[str] = None
|
self._ffmpeg_error: Optional[str] = None
|
||||||
@@ -115,6 +119,7 @@ class AudioProcessor:
|
|||||||
self.translation_task: Optional[asyncio.Task] = None
|
self.translation_task: Optional[asyncio.Task] = None
|
||||||
self.watchdog_task: Optional[asyncio.Task] = None
|
self.watchdog_task: Optional[asyncio.Task] = None
|
||||||
self.all_tasks_for_cleanup: List[asyncio.Task] = []
|
self.all_tasks_for_cleanup: List[asyncio.Task] = []
|
||||||
|
self.metrics: SessionMetrics = SessionMetrics()
|
||||||
|
|
||||||
self.transcription: Optional[Any] = None
|
self.transcription: Optional[Any] = None
|
||||||
self.translation: Optional[Any] = None
|
self.translation: Optional[Any] = None
|
||||||
@@ -136,25 +141,43 @@ class AudioProcessor:
|
|||||||
if self.translation_queue:
|
if self.translation_queue:
|
||||||
await self.translation_queue.put(self.current_silence)
|
await self.translation_queue.put(self.current_silence)
|
||||||
|
|
||||||
async def _begin_silence(self) -> None:
|
async def _begin_silence(self, at_sample: Optional[int] = None) -> None:
|
||||||
if self.current_silence:
|
if self.current_silence:
|
||||||
return
|
return
|
||||||
now = time() - self.beg_loop
|
# Use audio stream time (sample-precise) for accurate silence duration
|
||||||
|
if at_sample is not None:
|
||||||
|
audio_t = at_sample / self.sample_rate
|
||||||
|
else:
|
||||||
|
audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
|
||||||
self.current_silence = Silence(
|
self.current_silence = Silence(
|
||||||
is_starting=True, start=now
|
is_starting=True, start=audio_t
|
||||||
)
|
)
|
||||||
await self._push_silence_event()
|
# Push a separate start-only event so _end_silence won't mutate it
|
||||||
|
start_event = Silence(is_starting=True, start=audio_t)
|
||||||
|
if self.transcription_queue:
|
||||||
|
await self.transcription_queue.put(start_event)
|
||||||
|
if self.args.diarization and self.diarization_queue:
|
||||||
|
await self.diarization_queue.put(start_event)
|
||||||
|
if self.translation_queue:
|
||||||
|
await self.translation_queue.put(start_event)
|
||||||
|
|
||||||
async def _end_silence(self) -> None:
|
async def _end_silence(self, at_sample: Optional[int] = None) -> None:
|
||||||
if not self.current_silence:
|
if not self.current_silence:
|
||||||
return
|
return
|
||||||
now = time() - self.beg_loop
|
if at_sample is not None:
|
||||||
self.current_silence.end = now
|
audio_t = at_sample / self.sample_rate
|
||||||
self.current_silence.is_starting=False
|
else:
|
||||||
self.current_silence.has_ended=True
|
audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
|
||||||
|
self.current_silence.end = audio_t
|
||||||
|
self.current_silence.is_starting = False
|
||||||
|
self.current_silence.has_ended = True
|
||||||
self.current_silence.compute_duration()
|
self.current_silence.compute_duration()
|
||||||
|
self.metrics.n_silence_events += 1
|
||||||
|
if self.current_silence.duration is not None:
|
||||||
|
self.metrics.total_silence_duration_s += self.current_silence.duration
|
||||||
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
||||||
self.state.new_tokens.append(self.current_silence)
|
self.state.new_tokens.append(self.current_silence)
|
||||||
|
# Push the completed silence as the end event (separate from the start event)
|
||||||
await self._push_silence_event()
|
await self._push_silence_event()
|
||||||
self.current_silence = None
|
self.current_silence = None
|
||||||
|
|
||||||
@@ -250,6 +273,34 @@ class AudioProcessor:
|
|||||||
if self.translation:
|
if self.translation:
|
||||||
await self.translation_queue.put(SENTINEL)
|
await self.translation_queue.put(SENTINEL)
|
||||||
|
|
||||||
|
async def _finish_transcription(self) -> None:
|
||||||
|
"""Call finish() on the online processor to flush remaining tokens."""
|
||||||
|
if not self.transcription:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if hasattr(self.transcription, 'finish'):
|
||||||
|
final_tokens, end_time = await asyncio.to_thread(self.transcription.finish)
|
||||||
|
else:
|
||||||
|
# SimulStreamingOnlineProcessor uses start_silence() → process_iter(is_last=True)
|
||||||
|
final_tokens, end_time = await asyncio.to_thread(self.transcription.start_silence)
|
||||||
|
|
||||||
|
final_tokens = final_tokens or []
|
||||||
|
if final_tokens:
|
||||||
|
logger.info(f"Finish flushed {len(final_tokens)} tokens")
|
||||||
|
_buffer_transcript = self.transcription.get_buffer()
|
||||||
|
async with self.lock:
|
||||||
|
self.state.tokens.extend(final_tokens)
|
||||||
|
self.state.buffer_transcription = _buffer_transcript
|
||||||
|
self.state.end_buffer = max(self.state.end_buffer, end_time)
|
||||||
|
self.state.new_tokens.extend(final_tokens)
|
||||||
|
self.state.new_tokens_buffer = _buffer_transcript
|
||||||
|
if self.translation_queue:
|
||||||
|
for token in final_tokens:
|
||||||
|
await self.translation_queue.put(token)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error finishing transcription: {e}")
|
||||||
|
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
async def transcription_processor(self) -> None:
|
async def transcription_processor(self) -> None:
|
||||||
"""Process audio chunks for transcription."""
|
"""Process audio chunks for transcription."""
|
||||||
cumulative_pcm_duration_stream_time = 0.0
|
cumulative_pcm_duration_stream_time = 0.0
|
||||||
@@ -260,6 +311,7 @@ class AudioProcessor:
|
|||||||
item = await get_all_from_queue(self.transcription_queue)
|
item = await get_all_from_queue(self.transcription_queue)
|
||||||
if item is SENTINEL:
|
if item is SENTINEL:
|
||||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||||
|
await self._finish_transcription()
|
||||||
break
|
break
|
||||||
|
|
||||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||||
@@ -294,8 +346,13 @@ class AudioProcessor:
|
|||||||
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
|
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
|
||||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||||
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||||
|
_t0 = time()
|
||||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
|
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
|
||||||
|
_dur = time() - _t0
|
||||||
|
self.metrics.transcription_durations.append(_dur)
|
||||||
|
self.metrics.n_transcription_calls += 1
|
||||||
new_tokens = new_tokens or []
|
new_tokens = new_tokens or []
|
||||||
|
self.metrics.n_tokens_produced += len(new_tokens)
|
||||||
|
|
||||||
_buffer_transcript = self.transcription.get_buffer()
|
_buffer_transcript = self.transcription.get_buffer()
|
||||||
buffer_text = _buffer_transcript.text
|
buffer_text = _buffer_transcript.text
|
||||||
@@ -351,11 +408,14 @@ class AudioProcessor:
|
|||||||
if item.has_ended:
|
if item.has_ended:
|
||||||
self.diarization.insert_silence(item.duration)
|
self.diarization.insert_silence(item.duration)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.diarization.insert_audio_chunk(item)
|
self.diarization.insert_audio_chunk(item)
|
||||||
diarization_segments = await self.diarization.diarize()
|
diarization_segments = await self.diarization.diarize()
|
||||||
|
diar_end = 0.0
|
||||||
|
if diarization_segments:
|
||||||
|
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
|
||||||
|
async with self.lock:
|
||||||
self.state.new_diarization = diarization_segments
|
self.state.new_diarization = diarization_segments
|
||||||
|
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in diarization_processor: {e}")
|
logger.warning(f"Exception in diarization_processor: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
@@ -427,6 +487,7 @@ class AudioProcessor:
|
|||||||
|
|
||||||
should_push = (response != self.last_response_content)
|
should_push = (response != self.last_response_content)
|
||||||
if should_push:
|
if should_push:
|
||||||
|
self.metrics.n_responses_sent += 1
|
||||||
yield response
|
yield response
|
||||||
self.last_response_content = response
|
self.last_response_content = response
|
||||||
|
|
||||||
@@ -529,6 +590,10 @@ class AudioProcessor:
|
|||||||
logger.warning(f"Error stopping FFmpeg manager: {e}")
|
logger.warning(f"Error stopping FFmpeg manager: {e}")
|
||||||
if self.diarization:
|
if self.diarization:
|
||||||
self.diarization.close()
|
self.diarization.close()
|
||||||
|
|
||||||
|
# Finalize session metrics
|
||||||
|
self.metrics.total_audio_duration_s = self.total_pcm_samples / self.sample_rate
|
||||||
|
self.metrics.log_summary()
|
||||||
logger.info("AudioProcessor cleanup complete.")
|
logger.info("AudioProcessor cleanup complete.")
|
||||||
|
|
||||||
def _processing_tasks_done(self) -> bool:
|
def _processing_tasks_done(self) -> bool:
|
||||||
@@ -547,6 +612,7 @@ class AudioProcessor:
|
|||||||
|
|
||||||
if not self.beg_loop:
|
if not self.beg_loop:
|
||||||
self.beg_loop = time()
|
self.beg_loop = time()
|
||||||
|
self.metrics.session_start = self.beg_loop
|
||||||
self.current_silence = Silence(start=0.0, is_starting=True)
|
self.current_silence = Silence(start=0.0, is_starting=True)
|
||||||
self.tokens_alignment.beg_loop = self.beg_loop
|
self.tokens_alignment.beg_loop = self.beg_loop
|
||||||
|
|
||||||
@@ -554,6 +620,10 @@ class AudioProcessor:
|
|||||||
logger.info("Empty audio message received, initiating stop sequence.")
|
logger.info("Empty audio message received, initiating stop sequence.")
|
||||||
self.is_stopping = True
|
self.is_stopping = True
|
||||||
|
|
||||||
|
# Flush any remaining PCM data before signaling end-of-stream
|
||||||
|
if self.is_pcm_input and self.pcm_buffer:
|
||||||
|
await self._flush_remaining_pcm()
|
||||||
|
|
||||||
if self.transcription_queue:
|
if self.transcription_queue:
|
||||||
await self.transcription_queue.put(SENTINEL)
|
await self.transcription_queue.put(SENTINEL)
|
||||||
|
|
||||||
@@ -566,6 +636,8 @@ class AudioProcessor:
|
|||||||
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
|
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self.metrics.n_chunks_received += 1
|
||||||
|
|
||||||
if self.is_pcm_input:
|
if self.is_pcm_input:
|
||||||
self.pcm_buffer.extend(message)
|
self.pcm_buffer.extend(message)
|
||||||
await self.handle_pcm_data()
|
await self.handle_pcm_data()
|
||||||
@@ -582,6 +654,11 @@ class AudioProcessor:
|
|||||||
logger.warning("Failed to write audio data to FFmpeg")
|
logger.warning("Failed to write audio data to FFmpeg")
|
||||||
|
|
||||||
async def handle_pcm_data(self) -> None:
|
async def handle_pcm_data(self) -> None:
|
||||||
|
# Without VAC, there's no speech detector to end the initial silence.
|
||||||
|
# Clear it on the first audio chunk so audio actually gets enqueued.
|
||||||
|
if not self.args.vac and self.current_silence:
|
||||||
|
await self._end_silence()
|
||||||
|
|
||||||
# Process when enough data
|
# Process when enough data
|
||||||
if len(self.pcm_buffer) < self.bytes_per_sec:
|
if len(self.pcm_buffer) < self.bytes_per_sec:
|
||||||
return
|
return
|
||||||
@@ -610,7 +687,7 @@ class AudioProcessor:
|
|||||||
|
|
||||||
if res is not None:
|
if res is not None:
|
||||||
if "start" in res and self.current_silence:
|
if "start" in res and self.current_silence:
|
||||||
await self._end_silence()
|
await self._end_silence(at_sample=res.get("start"))
|
||||||
|
|
||||||
if "end" in res and not self.current_silence:
|
if "end" in res and not self.current_silence:
|
||||||
pre_silence_chunk = self._slice_before_silence(
|
pre_silence_chunk = self._slice_before_silence(
|
||||||
@@ -618,7 +695,7 @@ class AudioProcessor:
|
|||||||
)
|
)
|
||||||
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
||||||
await self._enqueue_active_audio(pre_silence_chunk)
|
await self._enqueue_active_audio(pre_silence_chunk)
|
||||||
await self._begin_silence()
|
await self._begin_silence(at_sample=res.get("end"))
|
||||||
|
|
||||||
if not self.current_silence:
|
if not self.current_silence:
|
||||||
await self._enqueue_active_audio(pcm_array)
|
await self._enqueue_active_audio(pcm_array)
|
||||||
@@ -627,3 +704,21 @@ class AudioProcessor:
|
|||||||
|
|
||||||
if not self.args.transcription and not self.args.diarization:
|
if not self.args.transcription and not self.args.diarization:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def _flush_remaining_pcm(self) -> None:
|
||||||
|
"""Flush whatever PCM data remains in the buffer, regardless of size threshold."""
|
||||||
|
if not self.pcm_buffer:
|
||||||
|
return
|
||||||
|
aligned_size = (len(self.pcm_buffer) // self.bytes_per_sample) * self.bytes_per_sample
|
||||||
|
if aligned_size == 0:
|
||||||
|
return
|
||||||
|
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_size])
|
||||||
|
self.pcm_buffer = self.pcm_buffer[aligned_size:]
|
||||||
|
|
||||||
|
# End any active silence so the audio gets enqueued
|
||||||
|
if self.current_silence:
|
||||||
|
await self._end_silence(at_sample=self.total_pcm_samples)
|
||||||
|
|
||||||
|
await self._enqueue_active_audio(pcm_array)
|
||||||
|
self.total_pcm_samples += len(pcm_array)
|
||||||
|
logger.info(f"Flushed remaining PCM buffer: {len(pcm_array)} samples ({len(pcm_array)/self.sample_rate:.2f}s)")
|
||||||
|
|||||||
@@ -29,6 +29,12 @@ def mlx_backend_available(warn_on_missing = False):
|
|||||||
return available
|
return available
|
||||||
|
|
||||||
|
|
||||||
|
def voxtral_hf_backend_available():
|
||||||
|
"""Return True if HF Transformers Voxtral backend is available."""
|
||||||
|
return module_available("transformers")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def faster_backend_available(warn_on_missing = False):
|
def faster_backend_available(warn_on_missing = False):
|
||||||
available = module_available("faster_whisper")
|
available = module_available("faster_whisper")
|
||||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||||
|
|||||||
@@ -14,15 +14,13 @@ logging.getLogger().setLevel(logging.WARNING)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
args = parse_args()
|
config = parse_args()
|
||||||
transcription_engine = None
|
transcription_engine = None
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global transcription_engine
|
global transcription_engine
|
||||||
transcription_engine = TranscriptionEngine(
|
transcription_engine = TranscriptionEngine(config=config)
|
||||||
**vars(args),
|
|
||||||
)
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
@@ -63,7 +61,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
logger.info("WebSocket connection opened.")
|
logger.info("WebSocket connection opened.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
|
await websocket.send_json({"type": "config", "useAudioWorklet": bool(config.pcm_input)})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to send config to client: {e}")
|
logger.warning(f"Failed to send config to client: {e}")
|
||||||
|
|
||||||
@@ -103,26 +101,26 @@ def main():
|
|||||||
|
|
||||||
uvicorn_kwargs = {
|
uvicorn_kwargs = {
|
||||||
"app": "whisperlivekit.basic_server:app",
|
"app": "whisperlivekit.basic_server:app",
|
||||||
"host":args.host,
|
"host": config.host,
|
||||||
"port":args.port,
|
"port": config.port,
|
||||||
"reload": False,
|
"reload": False,
|
||||||
"log_level": "info",
|
"log_level": "info",
|
||||||
"lifespan": "on",
|
"lifespan": "on",
|
||||||
}
|
}
|
||||||
|
|
||||||
ssl_kwargs = {}
|
ssl_kwargs = {}
|
||||||
if args.ssl_certfile or args.ssl_keyfile:
|
if config.ssl_certfile or config.ssl_keyfile:
|
||||||
if not (args.ssl_certfile and args.ssl_keyfile):
|
if not (config.ssl_certfile and config.ssl_keyfile):
|
||||||
raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
|
raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
|
||||||
ssl_kwargs = {
|
ssl_kwargs = {
|
||||||
"ssl_certfile": args.ssl_certfile,
|
"ssl_certfile": config.ssl_certfile,
|
||||||
"ssl_keyfile": args.ssl_keyfile
|
"ssl_keyfile": config.ssl_keyfile,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ssl_kwargs:
|
if ssl_kwargs:
|
||||||
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
||||||
if args.forwarded_allow_ips:
|
if config.forwarded_allow_ips:
|
||||||
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
|
uvicorn_kwargs = {**uvicorn_kwargs, "forwarded_allow_ips": config.forwarded_allow_ips}
|
||||||
|
|
||||||
uvicorn.run(**uvicorn_kwargs)
|
uvicorn.run(**uvicorn_kwargs)
|
||||||
|
|
||||||
|
|||||||
102
whisperlivekit/config.py
Normal file
102
whisperlivekit/config.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""Typed configuration for the WhisperLiveKit pipeline."""
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field, fields
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WhisperLiveKitConfig:
|
||||||
|
"""Single source of truth for all WhisperLiveKit configuration.
|
||||||
|
|
||||||
|
Replaces the previous dict-based parameter system in TranscriptionEngine.
|
||||||
|
All fields have defaults matching the prior behaviour.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Server / global
|
||||||
|
host: str = "localhost"
|
||||||
|
port: int = 8000
|
||||||
|
diarization: bool = False
|
||||||
|
punctuation_split: bool = False
|
||||||
|
target_language: str = ""
|
||||||
|
vac: bool = True
|
||||||
|
vac_chunk_size: float = 0.04
|
||||||
|
log_level: str = "DEBUG"
|
||||||
|
ssl_certfile: Optional[str] = None
|
||||||
|
ssl_keyfile: Optional[str] = None
|
||||||
|
forwarded_allow_ips: Optional[str] = None
|
||||||
|
transcription: bool = True
|
||||||
|
vad: bool = True
|
||||||
|
pcm_input: bool = False
|
||||||
|
disable_punctuation_split: bool = False
|
||||||
|
diarization_backend: str = "sortformer"
|
||||||
|
backend_policy: str = "simulstreaming"
|
||||||
|
backend: str = "auto"
|
||||||
|
|
||||||
|
# Transcription common
|
||||||
|
warmup_file: Optional[str] = None
|
||||||
|
min_chunk_size: float = 0.1
|
||||||
|
model_size: str = "base"
|
||||||
|
model_cache_dir: Optional[str] = None
|
||||||
|
model_dir: Optional[str] = None
|
||||||
|
model_path: Optional[str] = None
|
||||||
|
lora_path: Optional[str] = None
|
||||||
|
lan: str = "auto"
|
||||||
|
direct_english_translation: bool = False
|
||||||
|
|
||||||
|
# LocalAgreement-specific
|
||||||
|
buffer_trimming: str = "segment"
|
||||||
|
confidence_validation: bool = False
|
||||||
|
buffer_trimming_sec: float = 15.0
|
||||||
|
|
||||||
|
# SimulStreaming-specific
|
||||||
|
disable_fast_encoder: bool = False
|
||||||
|
custom_alignment_heads: Optional[str] = None
|
||||||
|
frame_threshold: int = 25
|
||||||
|
beams: int = 1
|
||||||
|
decoder_type: Optional[str] = None
|
||||||
|
audio_max_len: float = 20.0
|
||||||
|
audio_min_len: float = 0.0
|
||||||
|
cif_ckpt_path: Optional[str] = None
|
||||||
|
never_fire: bool = False
|
||||||
|
init_prompt: Optional[str] = None
|
||||||
|
static_init_prompt: Optional[str] = None
|
||||||
|
max_context_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
# Diarization (diart)
|
||||||
|
segmentation_model: str = "pyannote/segmentation-3.0"
|
||||||
|
embedding_model: str = "pyannote/embedding"
|
||||||
|
|
||||||
|
# Translation
|
||||||
|
nllb_backend: str = "transformers"
|
||||||
|
nllb_size: str = "600M"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# .en model suffix forces English
|
||||||
|
if self.model_size and self.model_size.endswith(".en"):
|
||||||
|
self.lan = "en"
|
||||||
|
# Normalize backend_policy aliases
|
||||||
|
if self.backend_policy == "1":
|
||||||
|
self.backend_policy = "simulstreaming"
|
||||||
|
elif self.backend_policy == "2":
|
||||||
|
self.backend_policy = "localagreement"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Factory helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_namespace(cls, ns) -> "WhisperLiveKitConfig":
|
||||||
|
"""Create config from an argparse Namespace, ignoring unknown keys."""
|
||||||
|
known = {f.name for f in fields(cls)}
|
||||||
|
return cls(**{k: v for k, v in vars(ns).items() if k in known})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_kwargs(cls, **kwargs) -> "WhisperLiveKitConfig":
|
||||||
|
"""Create config from keyword arguments; warns on unknown keys."""
|
||||||
|
known = {f.name for f in fields(cls)}
|
||||||
|
unknown = set(kwargs.keys()) - known
|
||||||
|
if unknown:
|
||||||
|
logger.warning("Unknown config keys ignored: %s", unknown)
|
||||||
|
return cls(**{k: v for k, v in kwargs.items() if k in known})
|
||||||
@@ -1,132 +1,142 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
from whisperlivekit.config import WhisperLiveKitConfig
|
||||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||||
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||||
|
|
||||||
|
|
||||||
def update_with_kwargs(_dict, kwargs):
|
|
||||||
_dict.update({
|
|
||||||
k: v for k, v in kwargs.items() if k in _dict
|
|
||||||
})
|
|
||||||
return _dict
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
_lock = threading.Lock() # Thread-safe singleton lock
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# Double-checked locking pattern for thread-safe singleton
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
# Check again inside lock to prevent race condition
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, config=None, **kwargs):
|
||||||
|
# Thread-safe initialization check
|
||||||
|
with TranscriptionEngine._lock:
|
||||||
if TranscriptionEngine._initialized:
|
if TranscriptionEngine._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
global_params = {
|
try:
|
||||||
"host": "localhost",
|
self._do_init(config, **kwargs)
|
||||||
"port": 8000,
|
except Exception:
|
||||||
"diarization": False,
|
# Reset singleton so a retry is possible
|
||||||
"punctuation_split": False,
|
with TranscriptionEngine._lock:
|
||||||
"target_language": "",
|
TranscriptionEngine._instance = None
|
||||||
"vac": True,
|
TranscriptionEngine._initialized = False
|
||||||
"vac_onnx": False,
|
raise
|
||||||
"vac_chunk_size": 0.04,
|
|
||||||
"log_level": "DEBUG",
|
|
||||||
"ssl_certfile": None,
|
|
||||||
"ssl_keyfile": None,
|
|
||||||
"forwarded_allow_ips": None,
|
|
||||||
"transcription": True,
|
|
||||||
"vad": True,
|
|
||||||
"pcm_input": False,
|
|
||||||
"disable_punctuation_split" : False,
|
|
||||||
"diarization_backend": "sortformer",
|
|
||||||
"backend_policy": "simulstreaming",
|
|
||||||
"backend": "auto",
|
|
||||||
}
|
|
||||||
global_params = update_with_kwargs(global_params, kwargs)
|
|
||||||
|
|
||||||
transcription_common_params = {
|
with TranscriptionEngine._lock:
|
||||||
"warmup_file": None,
|
TranscriptionEngine._initialized = True
|
||||||
"min_chunk_size": 0.1,
|
|
||||||
"model_size": "base",
|
|
||||||
"model_cache_dir": None,
|
|
||||||
"model_dir": None,
|
|
||||||
"model_path": None,
|
|
||||||
"lan": "auto",
|
|
||||||
"direct_english_translation": False,
|
|
||||||
}
|
|
||||||
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
|
||||||
|
|
||||||
if transcription_common_params['model_size'].endswith(".en"):
|
def _do_init(self, config=None, **kwargs):
|
||||||
transcription_common_params["lan"] = "en"
|
# Handle negated kwargs from programmatic API
|
||||||
if 'no_transcription' in kwargs:
|
if 'no_transcription' in kwargs:
|
||||||
global_params['transcription'] = not global_params['no_transcription']
|
kwargs['transcription'] = not kwargs.pop('no_transcription')
|
||||||
if 'no_vad' in kwargs:
|
if 'no_vad' in kwargs:
|
||||||
global_params['vad'] = not kwargs['no_vad']
|
kwargs['vad'] = not kwargs.pop('no_vad')
|
||||||
if 'no_vac' in kwargs:
|
if 'no_vac' in kwargs:
|
||||||
global_params['vac'] = not kwargs['no_vac']
|
kwargs['vac'] = not kwargs.pop('no_vac')
|
||||||
|
|
||||||
self.args = Namespace(**{**global_params, **transcription_common_params})
|
if config is None:
|
||||||
|
if isinstance(kwargs.get('config'), WhisperLiveKitConfig):
|
||||||
|
config = kwargs.pop('config')
|
||||||
|
else:
|
||||||
|
config = WhisperLiveKitConfig.from_kwargs(**kwargs)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Backward compat: expose as self.args (Namespace-like) for AudioProcessor etc.
|
||||||
|
self.args = Namespace(**asdict(config))
|
||||||
|
|
||||||
self.asr = None
|
self.asr = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.diarization = None
|
self.diarization = None
|
||||||
self.vac_model = None
|
self.vac_session = None
|
||||||
|
|
||||||
if self.args.vac:
|
if config.vac:
|
||||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
from whisperlivekit.silero_vad_iterator import is_onnx_available
|
||||||
|
|
||||||
# Use ONNX if specified, otherwise use JIT (default)
|
if is_onnx_available():
|
||||||
use_onnx = kwargs.get('vac_onnx', False)
|
from whisperlivekit.silero_vad_iterator import load_onnx_session
|
||||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
self.vac_session = load_onnx_session()
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
|
||||||
|
"For multi-user scenarios, install onnxruntime: pip install onnxruntime"
|
||||||
|
)
|
||||||
|
|
||||||
backend_policy = self.args.backend_policy
|
transcription_common_params = {
|
||||||
if self.args.transcription:
|
"warmup_file": config.warmup_file,
|
||||||
if backend_policy == "simulstreaming":
|
"min_chunk_size": config.min_chunk_size,
|
||||||
simulstreaming_params = {
|
"model_size": config.model_size,
|
||||||
"disable_fast_encoder": False,
|
"model_cache_dir": config.model_cache_dir,
|
||||||
"custom_alignment_heads": None,
|
"model_dir": config.model_dir,
|
||||||
"frame_threshold": 25,
|
"model_path": config.model_path,
|
||||||
"beams": 1,
|
"lora_path": config.lora_path,
|
||||||
"decoder_type": None,
|
"lan": config.lan,
|
||||||
"audio_max_len": 20.0,
|
"direct_english_translation": config.direct_english_translation,
|
||||||
"audio_min_len": 0.0,
|
}
|
||||||
"cif_ckpt_path": None,
|
|
||||||
"never_fire": False,
|
if config.transcription:
|
||||||
"init_prompt": None,
|
if config.backend == "voxtral-mlx":
|
||||||
"static_init_prompt": None,
|
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
|
||||||
"max_context_tokens": None,
|
self.tokenizer = None
|
||||||
|
self.asr = VoxtralMLXASR(**transcription_common_params)
|
||||||
|
logger.info("Using Voxtral MLX native backend")
|
||||||
|
elif config.backend == "voxtral":
|
||||||
|
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingASR
|
||||||
|
self.tokenizer = None
|
||||||
|
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
||||||
|
logger.info("Using Voxtral HF Transformers streaming backend")
|
||||||
|
elif config.backend_policy == "simulstreaming":
|
||||||
|
simulstreaming_params = {
|
||||||
|
"disable_fast_encoder": config.disable_fast_encoder,
|
||||||
|
"custom_alignment_heads": config.custom_alignment_heads,
|
||||||
|
"frame_threshold": config.frame_threshold,
|
||||||
|
"beams": config.beams,
|
||||||
|
"decoder_type": config.decoder_type,
|
||||||
|
"audio_max_len": config.audio_max_len,
|
||||||
|
"audio_min_len": config.audio_min_len,
|
||||||
|
"cif_ckpt_path": config.cif_ckpt_path,
|
||||||
|
"never_fire": config.never_fire,
|
||||||
|
"init_prompt": config.init_prompt,
|
||||||
|
"static_init_prompt": config.static_init_prompt,
|
||||||
|
"max_context_tokens": config.max_context_tokens,
|
||||||
}
|
}
|
||||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
|
||||||
|
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.asr = SimulStreamingASR(
|
self.asr = SimulStreamingASR(
|
||||||
**transcription_common_params,
|
**transcription_common_params,
|
||||||
**simulstreaming_params,
|
**simulstreaming_params,
|
||||||
backend=self.args.backend,
|
backend=config.backend,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using SimulStreaming policy with %s backend",
|
"Using SimulStreaming policy with %s backend",
|
||||||
getattr(self.asr, "encoder_backend", "whisper"),
|
getattr(self.asr, "encoder_backend", "whisper"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
whisperstreaming_params = {
|
whisperstreaming_params = {
|
||||||
"buffer_trimming": "segment",
|
"buffer_trimming": config.buffer_trimming,
|
||||||
"confidence_validation": False,
|
"confidence_validation": config.confidence_validation,
|
||||||
"buffer_trimming_sec": 15,
|
"buffer_trimming_sec": config.buffer_trimming_sec,
|
||||||
}
|
}
|
||||||
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
|
||||||
|
|
||||||
self.asr = backend_factory(
|
self.asr = backend_factory(
|
||||||
backend=self.args.backend,
|
backend=config.backend,
|
||||||
**transcription_common_params,
|
**transcription_common_params,
|
||||||
**whisperstreaming_params,
|
**whisperstreaming_params,
|
||||||
)
|
)
|
||||||
@@ -135,60 +145,57 @@ class TranscriptionEngine:
|
|||||||
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.diarization:
|
if config.diarization:
|
||||||
if self.args.diarization_backend == "diart":
|
if config.diarization_backend == "diart":
|
||||||
from whisperlivekit.diarization.diart_backend import \
|
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||||
DiartDiarization
|
|
||||||
diart_params = {
|
|
||||||
"segmentation_model": "pyannote/segmentation-3.0",
|
|
||||||
"embedding_model": "pyannote/embedding",
|
|
||||||
}
|
|
||||||
diart_params = update_with_kwargs(diart_params, kwargs)
|
|
||||||
self.diarization_model = DiartDiarization(
|
self.diarization_model = DiartDiarization(
|
||||||
block_duration=self.args.min_chunk_size,
|
block_duration=config.min_chunk_size,
|
||||||
**diart_params
|
segmentation_model=config.segmentation_model,
|
||||||
|
embedding_model=config.embedding_model,
|
||||||
)
|
)
|
||||||
elif self.args.diarization_backend == "sortformer":
|
elif config.diarization_backend == "sortformer":
|
||||||
from whisperlivekit.diarization.sortformer_backend import \
|
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||||
SortformerDiarization
|
|
||||||
self.diarization_model = SortformerDiarization()
|
self.diarization_model = SortformerDiarization()
|
||||||
|
|
||||||
self.translation_model = None
|
self.translation_model = None
|
||||||
if self.args.target_language:
|
if config.target_language:
|
||||||
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
|
if config.lan == 'auto' and config.backend_policy != "simulstreaming":
|
||||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
raise ValueError('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
from nllw import load_model
|
from nllw import load_model
|
||||||
except:
|
except ImportError:
|
||||||
raise Exception('To use translation, you must install nllw: `pip install nllw`')
|
raise ImportError('To use translation, you must install nllw: `pip install nllw`')
|
||||||
translation_params = {
|
self.translation_model = load_model(
|
||||||
"nllb_backend": "transformers",
|
[config.lan],
|
||||||
"nllb_size": "600M"
|
nllb_backend=config.nllb_backend,
|
||||||
}
|
nllb_size=config.nllb_size,
|
||||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
)
|
||||||
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
|
||||||
TranscriptionEngine._initialized = True
|
|
||||||
|
|
||||||
|
|
||||||
def online_factory(args, asr):
|
def online_factory(args, asr):
|
||||||
|
if getattr(args, 'backend', None) == "voxtral-mlx":
|
||||||
|
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
|
||||||
|
return VoxtralMLXOnlineProcessor(asr)
|
||||||
|
if getattr(args, 'backend', None) == "voxtral":
|
||||||
|
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
|
||||||
|
return VoxtralHFStreamingOnlineProcessor(asr)
|
||||||
if args.backend_policy == "simulstreaming":
|
if args.backend_policy == "simulstreaming":
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||||
online = SimulStreamingOnlineProcessor(asr)
|
return SimulStreamingOnlineProcessor(asr)
|
||||||
else:
|
return OnlineASRProcessor(asr)
|
||||||
online = OnlineASRProcessor(asr)
|
|
||||||
return online
|
|
||||||
|
|
||||||
|
|
||||||
def online_diarization_factory(args, diarization_backend):
|
def online_diarization_factory(args, diarization_backend):
|
||||||
if args.diarization_backend == "diart":
|
if args.diarization_backend == "diart":
|
||||||
online = diarization_backend
|
online = diarization_backend
|
||||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
||||||
|
elif args.diarization_backend == "sortformer":
|
||||||
if args.diarization_backend == "sortformer":
|
|
||||||
from whisperlivekit.diarization.sortformer_backend import \
|
from whisperlivekit.diarization.sortformer_backend import \
|
||||||
SortformerDiarizationOnline
|
SortformerDiarizationOnline
|
||||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")
|
||||||
return online
|
return online
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from queue import Empty, SimpleQueue
|
from queue import Empty, SimpleQueue
|
||||||
@@ -14,14 +13,11 @@ from diart.sources import AudioSource, MicrophoneAudioSource
|
|||||||
from pyannote.core import Annotation
|
from pyannote.core import Annotation
|
||||||
from rx.core import Observer
|
from rx.core import Observer
|
||||||
|
|
||||||
|
from whisperlivekit.diarization.utils import extract_number
|
||||||
from whisperlivekit.timed_objects import SpeakerSegment
|
from whisperlivekit.timed_objects import SpeakerSegment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def extract_number(s: str) -> int:
|
|
||||||
m = re.search(r'\d+', s)
|
|
||||||
return int(m.group()) if m else None
|
|
||||||
|
|
||||||
class DiarizationObserver(Observer):
|
class DiarizationObserver(Observer):
|
||||||
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
||||||
|
|
||||||
@@ -202,14 +198,14 @@ class DiartDiarization:
|
|||||||
def insert_silence(self, silence_duration):
|
def insert_silence(self, silence_duration):
|
||||||
self.observer.global_time_offset += silence_duration
|
self.observer.global_time_offset += silence_duration
|
||||||
|
|
||||||
async def diarize(self, pcm_array: np.ndarray):
|
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||||
"""
|
"""Buffer audio for the next diarization step."""
|
||||||
Process audio data for diarization.
|
|
||||||
Only used when working with WebSocketAudioSource.
|
|
||||||
"""
|
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.push_audio(pcm_array)
|
self.custom_source.push_audio(pcm_array)
|
||||||
# self.observer.clear_old_segments()
|
|
||||||
|
async def diarize(self):
|
||||||
|
"""Return the current speaker segments from the diarization pipeline."""
|
||||||
|
return self.observer.get_segments()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the audio source."""
|
"""Close the audio source."""
|
||||||
|
|||||||
@@ -287,11 +287,7 @@ class SortformerDiarizationOnline:
|
|||||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||||
|
|
||||||
|
|
||||||
def extract_number(s: str) -> int:
|
from whisperlivekit.diarization.utils import extract_number
|
||||||
"""Extract number from speaker string (compatibility function)."""
|
|
||||||
import re
|
|
||||||
m = re.search(r'\d+', s)
|
|
||||||
return int(m.group()) if m else 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
7
whisperlivekit/diarization/utils.py
Normal file
7
whisperlivekit/diarization/utils.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def extract_number(s: str) -> int:
|
||||||
|
"""Extract the first integer from a string, e.g. 'speaker_2' -> 2."""
|
||||||
|
m = re.search(r'\d+', s)
|
||||||
|
return int(m.group()) if m else 0
|
||||||
@@ -7,7 +7,7 @@ from typing import List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||||
|
|
||||||
@@ -16,22 +16,16 @@ class ASRBase:
|
|||||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||||
# "" for faster-whisper because it emits the spaces when needed)
|
# "" for faster-whisper because it emits the spaces when needed)
|
||||||
|
|
||||||
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.transcribe_kargs = {}
|
self.transcribe_kargs = {}
|
||||||
|
self.lora_path = lora_path
|
||||||
if lan == "auto":
|
if lan == "auto":
|
||||||
self.original_language = None
|
self.original_language = None
|
||||||
else:
|
else:
|
||||||
self.original_language = lan
|
self.original_language = lan
|
||||||
self.model = self.load_model(model_size, cache_dir, model_dir)
|
self.model = self.load_model(model_size, cache_dir, model_dir)
|
||||||
|
|
||||||
def with_offset(self, offset: float) -> ASRToken:
|
|
||||||
# This method is kept for compatibility (typically you will use ASRToken.with_offset)
|
|
||||||
return ASRToken(self.start + offset, self.end + offset, self.text)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
|
|
||||||
|
|
||||||
def load_model(self, model_size, cache_dir, model_dir):
|
def load_model(self, model_size, cache_dir, model_dir):
|
||||||
raise NotImplementedError("must be implemented in the child class")
|
raise NotImplementedError("must be implemented in the child class")
|
||||||
|
|
||||||
@@ -47,24 +41,23 @@ class WhisperASR(ASRBase):
|
|||||||
sep = " "
|
sep = " "
|
||||||
|
|
||||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||||
from whisperlivekit.whisper import load_model as load_model
|
from whisperlivekit.whisper import load_model as load_whisper_model
|
||||||
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
resolved_path = resolve_model_path(model_dir)
|
resolved_path = resolve_model_path(model_dir)
|
||||||
if resolved_path.is_dir():
|
if resolved_path.is_dir():
|
||||||
pytorch_path, _, _ = model_path_and_type(resolved_path)
|
model_info = detect_model_format(resolved_path)
|
||||||
if pytorch_path is None:
|
if not model_info.has_pytorch:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||||
)
|
)
|
||||||
resolved_path = pytorch_path
|
|
||||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||||
return load_model(str(resolved_path))
|
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
|
||||||
|
|
||||||
if model_size is None:
|
if model_size is None:
|
||||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||||
|
|
||||||
return load_model(model_size, download_root=cache_dir)
|
return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)
|
||||||
|
|
||||||
def transcribe(self, audio, init_prompt=""):
|
def transcribe(self, audio, init_prompt=""):
|
||||||
options = dict(self.transcribe_kargs)
|
options = dict(self.transcribe_kargs)
|
||||||
@@ -187,22 +180,8 @@ class MLXWhisper(ASRBase):
|
|||||||
return transcribe
|
return transcribe
|
||||||
|
|
||||||
def translate_model_name(self, model_name):
|
def translate_model_name(self, model_name):
|
||||||
model_mapping = {
|
from whisperlivekit.model_mapping import MLX_MODEL_MAPPING
|
||||||
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
mlx_model_path = MLX_MODEL_MAPPING.get(model_name)
|
||||||
"tiny": "mlx-community/whisper-tiny-mlx",
|
|
||||||
"base.en": "mlx-community/whisper-base.en-mlx",
|
|
||||||
"base": "mlx-community/whisper-base-mlx",
|
|
||||||
"small.en": "mlx-community/whisper-small.en-mlx",
|
|
||||||
"small": "mlx-community/whisper-small-mlx",
|
|
||||||
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
|
||||||
"medium": "mlx-community/whisper-medium-mlx",
|
|
||||||
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
|
||||||
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
|
||||||
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
|
||||||
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
|
||||||
"large": "mlx-community/whisper-large-mlx",
|
|
||||||
}
|
|
||||||
mlx_model_path = model_mapping.get(model_name)
|
|
||||||
if mlx_model_path:
|
if mlx_model_path:
|
||||||
return mlx_model_path
|
return mlx_model_path
|
||||||
else:
|
else:
|
||||||
@@ -227,7 +206,6 @@ class MLXWhisper(ASRBase):
|
|||||||
if segment.get("no_speech_prob", 0) > 0.9:
|
if segment.get("no_speech_prob", 0) > 0.9:
|
||||||
continue
|
continue
|
||||||
for word in segment.get("words", []):
|
for word in segment.get("words", []):
|
||||||
probability=word["probability"]
|
|
||||||
token = ASRToken(word["start"], word["end"], word["word"])
|
token = ASRToken(word["start"], word["end"], word["word"])
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
return tokens
|
return tokens
|
||||||
@@ -238,6 +216,7 @@ class MLXWhisper(ASRBase):
|
|||||||
def use_vad(self):
|
def use_vad(self):
|
||||||
self.transcribe_kargs["vad_filter"] = True
|
self.transcribe_kargs["vad_filter"] = True
|
||||||
|
|
||||||
|
|
||||||
class OpenaiApiASR(ASRBase):
|
class OpenaiApiASR(ASRBase):
|
||||||
"""Uses OpenAI's Whisper API for transcription."""
|
"""Uses OpenAI's Whisper API for transcription."""
|
||||||
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
||||||
@@ -249,6 +228,7 @@ class OpenaiApiASR(ASRBase):
|
|||||||
self.load_model()
|
self.load_model()
|
||||||
self.use_vad_opt = False
|
self.use_vad_opt = False
|
||||||
self.direct_english_translation = False
|
self.direct_english_translation = False
|
||||||
|
self.task = "transcribe"
|
||||||
|
|
||||||
def load_model(self, *args, **kwargs):
|
def load_model(self, *args, **kwargs):
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -294,7 +274,8 @@ class OpenaiApiASR(ASRBase):
|
|||||||
params["language"] = self.original_language
|
params["language"] = self.original_language
|
||||||
if prompt:
|
if prompt:
|
||||||
params["prompt"] = prompt
|
params["prompt"] = prompt
|
||||||
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
|
task = self.transcribe_kargs.get("task", self.task)
|
||||||
|
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
|
||||||
transcript = proc.create(**params)
|
transcript = proc.create(**params)
|
||||||
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
||||||
return transcript
|
return transcript
|
||||||
|
|||||||
@@ -136,6 +136,11 @@ class OnlineASRProcessor:
|
|||||||
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
|
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker):
|
||||||
|
"""Handle speaker change event."""
|
||||||
|
self.process_iter()
|
||||||
|
self.init(offset=change_speaker.start)
|
||||||
|
|
||||||
def init(self, offset: Optional[float] = None):
|
def init(self, offset: Optional[float] = None):
|
||||||
"""Initialize or reset the processing buffers."""
|
"""Initialize or reset the processing buffers."""
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
|||||||
@@ -1,16 +1,11 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
import librosa
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from whisperlivekit.backend_support import (faster_backend_available,
|
from whisperlivekit.backend_support import (faster_backend_available,
|
||||||
mlx_backend_available)
|
mlx_backend_available)
|
||||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||||
from whisperlivekit.warmup import warmup_asr
|
from whisperlivekit.warmup import warmup_asr
|
||||||
|
|
||||||
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
|
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
|
||||||
@@ -77,6 +72,7 @@ def backend_factory(
|
|||||||
model_cache_dir,
|
model_cache_dir,
|
||||||
model_dir,
|
model_dir,
|
||||||
model_path,
|
model_path,
|
||||||
|
lora_path,
|
||||||
direct_english_translation,
|
direct_english_translation,
|
||||||
buffer_trimming,
|
buffer_trimming,
|
||||||
buffer_trimming_sec,
|
buffer_trimming_sec,
|
||||||
@@ -87,16 +83,20 @@ def backend_factory(
|
|||||||
backend_choice = backend
|
backend_choice = backend
|
||||||
custom_reference = model_path or model_dir
|
custom_reference = model_path or model_dir
|
||||||
resolved_root = None
|
resolved_root = None
|
||||||
pytorch_checkpoint = None
|
|
||||||
has_mlx_weights = False
|
has_mlx_weights = False
|
||||||
has_fw_weights = False
|
has_fw_weights = False
|
||||||
|
has_pytorch = False
|
||||||
|
|
||||||
if custom_reference:
|
if custom_reference:
|
||||||
resolved_root = resolve_model_path(custom_reference)
|
resolved_root = resolve_model_path(custom_reference)
|
||||||
if resolved_root.is_dir():
|
if resolved_root.is_dir():
|
||||||
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
|
model_info = detect_model_format(resolved_root)
|
||||||
|
has_mlx_weights = model_info.compatible_whisper_mlx
|
||||||
|
has_fw_weights = model_info.compatible_faster_whisper
|
||||||
|
has_pytorch = model_info.has_pytorch
|
||||||
else:
|
else:
|
||||||
pytorch_checkpoint = resolved_root
|
# Single file provided
|
||||||
|
has_pytorch = True
|
||||||
|
|
||||||
if backend_choice == "openai-api":
|
if backend_choice == "openai-api":
|
||||||
logger.debug("Using OpenAI API.")
|
logger.debug("Using OpenAI API.")
|
||||||
@@ -121,8 +121,8 @@ def backend_factory(
|
|||||||
model_override = str(resolved_root) if resolved_root is not None else None
|
model_override = str(resolved_root) if resolved_root is not None else None
|
||||||
else:
|
else:
|
||||||
asr_cls = WhisperASR
|
asr_cls = WhisperASR
|
||||||
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
|
model_override = str(resolved_root) if resolved_root is not None else None
|
||||||
if custom_reference and model_override is None:
|
if custom_reference and not has_pytorch:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
||||||
)
|
)
|
||||||
@@ -134,12 +134,14 @@ def backend_factory(
|
|||||||
lan=lan,
|
lan=lan,
|
||||||
cache_dir=model_cache_dir,
|
cache_dir=model_cache_dir,
|
||||||
model_dir=model_override,
|
model_dir=model_override,
|
||||||
|
lora_path=lora_path if backend_choice == "whisper" else None,
|
||||||
)
|
)
|
||||||
e = time.time()
|
e = time.time()
|
||||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||||
|
|
||||||
if direct_english_translation:
|
if direct_english_translation:
|
||||||
tgt_language = "en" # Whisper translates into English
|
tgt_language = "en" # Whisper translates into English
|
||||||
|
asr.transcribe_kargs["task"] = "translate"
|
||||||
else:
|
else:
|
||||||
tgt_language = lan # Whisper transcribes in this language
|
tgt_language = lan # Whisper transcribes in this language
|
||||||
|
|
||||||
|
|||||||
156
whisperlivekit/metrics.py
Normal file
156
whisperlivekit/metrics.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""Lightweight ASR evaluation metrics — no external dependencies.
|
||||||
|
|
||||||
|
Provides WER (Word Error Rate) computation via word-level Levenshtein distance,
|
||||||
|
text normalization, and word-level timestamp accuracy metrics with greedy alignment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text(text: str) -> str:
|
||||||
|
"""Normalize text for WER comparison: lowercase, strip punctuation, collapse whitespace."""
|
||||||
|
text = text.lower()
|
||||||
|
# Normalize unicode (e.g., accented chars to composed form)
|
||||||
|
text = unicodedata.normalize("NFC", text)
|
||||||
|
# Remove punctuation (keep letters, numbers, spaces, hyphens within words)
|
||||||
|
text = re.sub(r"[^\w\s\-']", " ", text)
|
||||||
|
# Collapse whitespace
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def compute_wer(reference: str, hypothesis: str) -> Dict:
|
||||||
|
"""Compute Word Error Rate using word-level Levenshtein edit distance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reference: Ground truth transcription.
|
||||||
|
hypothesis: Predicted transcription.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with keys: wer, substitutions, insertions, deletions, ref_words, hyp_words.
|
||||||
|
WER can exceed 1.0 if there are more errors than reference words.
|
||||||
|
"""
|
||||||
|
ref_words = normalize_text(reference).split()
|
||||||
|
hyp_words = normalize_text(hypothesis).split()
|
||||||
|
|
||||||
|
n = len(ref_words)
|
||||||
|
m = len(hyp_words)
|
||||||
|
|
||||||
|
if n == 0:
|
||||||
|
return {
|
||||||
|
"wer": 0.0 if m == 0 else float(m),
|
||||||
|
"substitutions": 0,
|
||||||
|
"insertions": m,
|
||||||
|
"deletions": 0,
|
||||||
|
"ref_words": 0,
|
||||||
|
"hyp_words": m,
|
||||||
|
}
|
||||||
|
|
||||||
|
# DP table: dp[i][j] = (edit_distance, substitutions, insertions, deletions)
|
||||||
|
dp = [[(0, 0, 0, 0) for _ in range(m + 1)] for _ in range(n + 1)]
|
||||||
|
|
||||||
|
for i in range(1, n + 1):
|
||||||
|
dp[i][0] = (i, 0, 0, i)
|
||||||
|
for j in range(1, m + 1):
|
||||||
|
dp[0][j] = (j, 0, j, 0)
|
||||||
|
|
||||||
|
for i in range(1, n + 1):
|
||||||
|
for j in range(1, m + 1):
|
||||||
|
if ref_words[i - 1] == hyp_words[j - 1]:
|
||||||
|
dp[i][j] = dp[i - 1][j - 1]
|
||||||
|
else:
|
||||||
|
sub = dp[i - 1][j - 1]
|
||||||
|
ins = dp[i][j - 1]
|
||||||
|
dele = dp[i - 1][j]
|
||||||
|
|
||||||
|
sub_cost = (sub[0] + 1, sub[1] + 1, sub[2], sub[3])
|
||||||
|
ins_cost = (ins[0] + 1, ins[1], ins[2] + 1, ins[3])
|
||||||
|
del_cost = (dele[0] + 1, dele[1], dele[2], dele[3] + 1)
|
||||||
|
|
||||||
|
dp[i][j] = min(sub_cost, del_cost, ins_cost, key=lambda x: x[0])
|
||||||
|
|
||||||
|
dist, subs, ins, dels = dp[n][m]
|
||||||
|
return {
|
||||||
|
"wer": dist / n,
|
||||||
|
"substitutions": subs,
|
||||||
|
"insertions": ins,
|
||||||
|
"deletions": dels,
|
||||||
|
"ref_words": n,
|
||||||
|
"hyp_words": m,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_timestamp_accuracy(
|
||||||
|
predicted: List[Dict],
|
||||||
|
reference: List[Dict],
|
||||||
|
) -> Dict:
|
||||||
|
"""Compute timestamp accuracy by aligning predicted words to reference words.
|
||||||
|
|
||||||
|
Uses greedy left-to-right alignment on normalized text. For each matched pair,
|
||||||
|
computes the start-time delta (predicted - reference).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predicted: List of dicts with keys: word, start, end.
|
||||||
|
reference: List of dicts with keys: word, start, end.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with keys: mae_start, max_delta_start, median_delta_start,
|
||||||
|
n_matched, n_ref, n_pred. Returns None values if no matches found.
|
||||||
|
"""
|
||||||
|
if not predicted or not reference:
|
||||||
|
return {
|
||||||
|
"mae_start": None,
|
||||||
|
"max_delta_start": None,
|
||||||
|
"median_delta_start": None,
|
||||||
|
"n_matched": 0,
|
||||||
|
"n_ref": len(reference),
|
||||||
|
"n_pred": len(predicted),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Normalize words for matching
|
||||||
|
pred_norm = [normalize_text(p["word"]) for p in predicted]
|
||||||
|
ref_norm = [normalize_text(r["word"]) for r in reference]
|
||||||
|
|
||||||
|
# Greedy left-to-right alignment
|
||||||
|
deltas_start = []
|
||||||
|
ref_idx = 0
|
||||||
|
for p_idx, p_word in enumerate(pred_norm):
|
||||||
|
if not p_word:
|
||||||
|
continue
|
||||||
|
# Scan forward in reference to find a match (allow small skips)
|
||||||
|
search_limit = min(ref_idx + 3, len(ref_norm))
|
||||||
|
for r_idx in range(ref_idx, search_limit):
|
||||||
|
if ref_norm[r_idx] == p_word:
|
||||||
|
delta = predicted[p_idx]["start"] - reference[r_idx]["start"]
|
||||||
|
deltas_start.append(delta)
|
||||||
|
ref_idx = r_idx + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if not deltas_start:
|
||||||
|
return {
|
||||||
|
"mae_start": None,
|
||||||
|
"max_delta_start": None,
|
||||||
|
"median_delta_start": None,
|
||||||
|
"n_matched": 0,
|
||||||
|
"n_ref": len(reference),
|
||||||
|
"n_pred": len(predicted),
|
||||||
|
}
|
||||||
|
|
||||||
|
abs_deltas = [abs(d) for d in deltas_start]
|
||||||
|
sorted_abs = sorted(abs_deltas)
|
||||||
|
n = len(sorted_abs)
|
||||||
|
if n % 2 == 1:
|
||||||
|
median = sorted_abs[n // 2]
|
||||||
|
else:
|
||||||
|
median = (sorted_abs[n // 2 - 1] + sorted_abs[n // 2]) / 2
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mae_start": sum(abs_deltas) / len(abs_deltas),
|
||||||
|
"max_delta_start": max(abs_deltas),
|
||||||
|
"median_delta_start": median,
|
||||||
|
"n_matched": len(deltas_start),
|
||||||
|
"n_ref": len(reference),
|
||||||
|
"n_pred": len(predicted),
|
||||||
|
}
|
||||||
84
whisperlivekit/metrics_collector.py
Normal file
84
whisperlivekit/metrics_collector.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Lightweight runtime metrics for AudioProcessor sessions.
|
||||||
|
|
||||||
|
Zero external dependencies. Negligible overhead when not queried —
|
||||||
|
just integer increments and list appends during normal operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionMetrics:
|
||||||
|
"""Per-session metrics collected by AudioProcessor."""
|
||||||
|
|
||||||
|
session_start: float = 0.0
|
||||||
|
total_audio_duration_s: float = 0.0
|
||||||
|
total_processing_time_s: float = 0.0
|
||||||
|
|
||||||
|
# Chunk / call counters
|
||||||
|
n_chunks_received: int = 0
|
||||||
|
n_transcription_calls: int = 0
|
||||||
|
n_tokens_produced: int = 0
|
||||||
|
n_responses_sent: int = 0
|
||||||
|
|
||||||
|
# Per-call ASR latency (seconds)
|
||||||
|
transcription_durations: List[float] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Silence
|
||||||
|
n_silence_events: int = 0
|
||||||
|
total_silence_duration_s: float = 0.0
|
||||||
|
|
||||||
|
# --- Computed properties ---
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rtf(self) -> float:
|
||||||
|
"""Real-time factor: processing_time / audio_duration."""
|
||||||
|
if self.total_audio_duration_s <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.total_processing_time_s / self.total_audio_duration_s
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg_latency_ms(self) -> float:
|
||||||
|
"""Average per-call ASR latency in milliseconds."""
|
||||||
|
if not self.transcription_durations:
|
||||||
|
return 0.0
|
||||||
|
return (sum(self.transcription_durations) / len(self.transcription_durations)) * 1000
|
||||||
|
|
||||||
|
@property
|
||||||
|
def p95_latency_ms(self) -> float:
|
||||||
|
"""95th percentile per-call ASR latency in milliseconds."""
|
||||||
|
if not self.transcription_durations:
|
||||||
|
return 0.0
|
||||||
|
sorted_d = sorted(self.transcription_durations)
|
||||||
|
idx = int(len(sorted_d) * 0.95)
|
||||||
|
idx = min(idx, len(sorted_d) - 1)
|
||||||
|
return sorted_d[idx] * 1000
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict:
|
||||||
|
"""Serialize to a plain dict (JSON-safe)."""
|
||||||
|
return {
|
||||||
|
"session_start": self.session_start,
|
||||||
|
"total_audio_duration_s": round(self.total_audio_duration_s, 3),
|
||||||
|
"total_processing_time_s": round(self.total_processing_time_s, 3),
|
||||||
|
"rtf": round(self.rtf, 3),
|
||||||
|
"n_chunks_received": self.n_chunks_received,
|
||||||
|
"n_transcription_calls": self.n_transcription_calls,
|
||||||
|
"n_tokens_produced": self.n_tokens_produced,
|
||||||
|
"n_responses_sent": self.n_responses_sent,
|
||||||
|
"avg_latency_ms": round(self.avg_latency_ms, 2),
|
||||||
|
"p95_latency_ms": round(self.p95_latency_ms, 2),
|
||||||
|
"n_silence_events": self.n_silence_events,
|
||||||
|
"total_silence_duration_s": round(self.total_silence_duration_s, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
def log_summary(self) -> None:
|
||||||
|
"""Emit a structured log line summarising the session."""
|
||||||
|
self.total_processing_time_s = sum(self.transcription_durations)
|
||||||
|
d = self.to_dict()
|
||||||
|
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
|
||||||
|
logger.info(f"SESSION_METRICS {d}")
|
||||||
17
whisperlivekit/model_mapping.py
Normal file
17
whisperlivekit/model_mapping.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Shared MLX model name mapping used by both SimulStreaming and LocalAgreement backends."""
|
||||||
|
|
||||||
|
MLX_MODEL_MAPPING = {
|
||||||
|
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
||||||
|
"tiny": "mlx-community/whisper-tiny-mlx",
|
||||||
|
"base.en": "mlx-community/whisper-base.en-mlx",
|
||||||
|
"base": "mlx-community/whisper-base-mlx",
|
||||||
|
"small.en": "mlx-community/whisper-small.en-mlx",
|
||||||
|
"small": "mlx-community/whisper-small-mlx",
|
||||||
|
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
||||||
|
"medium": "mlx-community/whisper-medium-mlx",
|
||||||
|
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
||||||
|
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
||||||
|
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
||||||
|
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
||||||
|
"large": "mlx-community/whisper-large-mlx",
|
||||||
|
}
|
||||||
@@ -1,49 +1,195 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelInfo:
|
||||||
|
"""Information about detected model format and files in a directory."""
|
||||||
|
path: Optional[Path] = None
|
||||||
|
pytorch_files: List[Path] = field(default_factory=list)
|
||||||
|
compatible_whisper_mlx: bool = False
|
||||||
|
compatible_faster_whisper: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_pytorch(self) -> bool:
|
||||||
|
return len(self.pytorch_files) > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_sharded(self) -> bool:
|
||||||
|
return len(self.pytorch_files) > 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def primary_pytorch_file(self) -> Optional[Path]:
|
||||||
|
"""Return the primary PyTorch file (or first shard for sharded models)."""
|
||||||
|
if not self.pytorch_files:
|
||||||
|
return None
|
||||||
|
return self.pytorch_files[0]
|
||||||
|
|
||||||
|
|
||||||
|
#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin
|
||||||
|
SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$")
|
||||||
|
|
||||||
|
FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"}
|
||||||
|
MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"}
|
||||||
|
CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
|
||||||
|
|
||||||
|
CTranslate2 models have specific companion files that distinguish them
|
||||||
|
from PyTorch .bin files.
|
||||||
|
"""
|
||||||
|
n_indicators = 0
|
||||||
|
for indicator in CT2_INDICATOR_FILES: #test 1
|
||||||
|
if (directory / indicator).exists():
|
||||||
|
n_indicators += 1
|
||||||
|
|
||||||
|
if n_indicators == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
config_path = directory / "config.json" #test 2
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
if config.get("model_type") == "whisper": #test 2
|
||||||
|
return False
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||||
|
"""
|
||||||
|
Collect all PyTorch checkpoint files from a directory.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Single files: model.safetensors, pytorch_model.bin, *.pt
|
||||||
|
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
|
||||||
|
- Index-based sharded models (reads index file to find shards)
|
||||||
|
|
||||||
|
Returns files sorted appropriately (shards in order, or single file).
|
||||||
|
"""
|
||||||
|
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
|
||||||
|
index_path = directory / index_name
|
||||||
|
if index_path.exists():
|
||||||
|
try:
|
||||||
|
with open(index_path, "r", encoding="utf-8") as f:
|
||||||
|
index_data = json.load(f)
|
||||||
|
weight_map = index_data.get("weight_map", {})
|
||||||
|
if weight_map:
|
||||||
|
shard_names = sorted(set(weight_map.values()))
|
||||||
|
shards = [directory / name for name in shard_names if (directory / name).exists()]
|
||||||
|
if shards:
|
||||||
|
return shards
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
sharded_groups = {}
|
||||||
|
single_files = {}
|
||||||
|
|
||||||
|
for file in directory.iterdir():
|
||||||
|
if not file.is_file():
|
||||||
|
continue
|
||||||
|
|
||||||
|
filename = file.name
|
||||||
|
suffix = file.suffix.lower()
|
||||||
|
|
||||||
|
if filename.startswith("adapter_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
match = SHARDED_PATTERN.match(filename)
|
||||||
|
if match:
|
||||||
|
base_name, shard_idx, total_shards, ext = match.groups()
|
||||||
|
key = (base_name, ext, int(total_shards))
|
||||||
|
if key not in sharded_groups:
|
||||||
|
sharded_groups[key] = []
|
||||||
|
sharded_groups[key].append((int(shard_idx), file))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if filename == "model.safetensors":
|
||||||
|
single_files[0] = file # Highest priority
|
||||||
|
elif filename == "pytorch_model.bin":
|
||||||
|
single_files[1] = file
|
||||||
|
elif suffix == ".pt":
|
||||||
|
single_files[2] = file
|
||||||
|
elif suffix == ".safetensors" and not filename.startswith("adapter"):
|
||||||
|
single_files[3] = file
|
||||||
|
|
||||||
|
for (base_name, ext, total_shards), shards in sharded_groups.items():
|
||||||
|
if len(shards) == total_shards:
|
||||||
|
return [path for _, path in sorted(shards)]
|
||||||
|
|
||||||
|
for priority in sorted(single_files.keys()):
|
||||||
|
return [single_files[priority]]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
|
||||||
|
"""
|
||||||
|
Detect the model format in a given path.
|
||||||
|
|
||||||
|
This function analyzes a file or directory to determine:
|
||||||
|
- What PyTorch checkpoint files are available (including sharded models)
|
||||||
|
- Whether the directory contains MLX Whisper weights
|
||||||
|
- Whether the directory contains Faster-Whisper (CTranslate2) weights
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to a model file or directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInfo with detected format information
|
||||||
|
"""
|
||||||
|
path = Path(model_path)
|
||||||
|
info = ModelInfo(path=path)
|
||||||
|
|
||||||
|
if path.is_file():
|
||||||
|
suffix = path.suffix.lower()
|
||||||
|
if suffix in {".pt", ".safetensors", ".bin"}:
|
||||||
|
info.pytorch_files = [path]
|
||||||
|
return info
|
||||||
|
|
||||||
|
if not path.is_dir():
|
||||||
|
return info
|
||||||
|
|
||||||
|
for file in path.iterdir():
|
||||||
|
if not file.is_file():
|
||||||
|
continue
|
||||||
|
|
||||||
|
filename = file.name.lower()
|
||||||
|
|
||||||
|
if filename in MLX_WHISPER_MARKERS:
|
||||||
|
info.compatible_whisper_mlx = True
|
||||||
|
|
||||||
|
if filename in FASTER_WHISPER_MARKERS:
|
||||||
|
if _is_ct2_model_bin(path, filename):
|
||||||
|
info.compatible_faster_whisper = True
|
||||||
|
|
||||||
|
info.pytorch_files = _collect_pytorch_files(path)
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||||
"""
|
"""
|
||||||
Inspect the provided path and determine which model formats are available.
|
Inspect the provided path and determine which model formats are available.
|
||||||
|
|
||||||
|
This is a compatibility wrapper around detect_model_format().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
pytorch_path: Path to a PyTorch checkpoint (if present).
|
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
|
||||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||||
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
|
compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
|
||||||
"""
|
"""
|
||||||
path = Path(model_path)
|
info = detect_model_format(model_path)
|
||||||
|
return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper
|
||||||
compatible_whisper_mlx = False
|
|
||||||
compatible_faster_whisper = False
|
|
||||||
pytorch_path: Optional[Path] = None
|
|
||||||
|
|
||||||
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
|
|
||||||
pytorch_path = path
|
|
||||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
|
||||||
|
|
||||||
if path.is_dir():
|
|
||||||
for file in path.iterdir():
|
|
||||||
if not file.is_file():
|
|
||||||
continue
|
|
||||||
|
|
||||||
filename = file.name.lower()
|
|
||||||
suffix = file.suffix.lower()
|
|
||||||
|
|
||||||
if filename in {"weights.npz", "weights.safetensors"}:
|
|
||||||
compatible_whisper_mlx = True
|
|
||||||
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
|
|
||||||
compatible_faster_whisper = True
|
|
||||||
elif suffix in {".pt", ".safetensors"}:
|
|
||||||
pytorch_path = file
|
|
||||||
elif filename == "pytorch_model.bin":
|
|
||||||
pytorch_path = file
|
|
||||||
|
|
||||||
if pytorch_path is None:
|
|
||||||
fallback = path / "pytorch_model.bin"
|
|
||||||
if fallback.exists():
|
|
||||||
pytorch_path = fallback
|
|
||||||
|
|
||||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||||
@@ -59,7 +205,7 @@ def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
except ImportError as exc: # pragma: no cover - optional dependency guard
|
except ImportError as exc:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||||
"is not installed to download it."
|
"is not installed to download it."
|
||||||
|
|||||||
@@ -106,6 +106,13 @@ def parse_args():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
dest="lora_path",
|
||||||
|
help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). Only works with native Whisper backend.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lan",
|
"--lan",
|
||||||
"--language",
|
"--language",
|
||||||
@@ -140,8 +147,8 @@ def parse_args():
|
|||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx"],
|
||||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-vac",
|
"--no-vac",
|
||||||
@@ -311,15 +318,12 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.transcription = not args.no_transcription
|
args.transcription = not args.no_transcription
|
||||||
args.vad = not args.no_vad
|
args.vad = not args.no_vad
|
||||||
|
args.vac = not args.no_vac
|
||||||
delattr(args, 'no_transcription')
|
delattr(args, 'no_transcription')
|
||||||
delattr(args, 'no_vad')
|
delattr(args, 'no_vad')
|
||||||
|
delattr(args, 'no_vac')
|
||||||
|
|
||||||
if args.backend_policy == "1":
|
from whisperlivekit.config import WhisperLiveKitConfig
|
||||||
args.backend_policy = "simulstreaming"
|
return WhisperLiveKitConfig.from_namespace(args)
|
||||||
elif args.backend_policy == "2":
|
|
||||||
args.backend_policy = "localagreement"
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|||||||
@@ -8,6 +8,15 @@ import torch
|
|||||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def is_onnx_available() -> bool:
|
||||||
|
"""Check if onnxruntime is installed."""
|
||||||
|
try:
|
||||||
|
import onnxruntime
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||||
"""Load a JIT model from file."""
|
"""Load a JIT model from file."""
|
||||||
model = torch.jit.load(model_path, map_location=device)
|
model = torch.jit.load(model_path, map_location=device)
|
||||||
@@ -15,12 +24,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class OnnxWrapper():
|
class OnnxSession():
|
||||||
"""ONNX Runtime wrapper for Silero VAD model."""
|
"""
|
||||||
|
Shared ONNX session for Silero VAD model (stateless).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, path, force_onnx_cpu=False):
|
def __init__(self, path, force_onnx_cpu=False):
|
||||||
global np
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
|
|
||||||
opts = onnxruntime.SessionOptions()
|
opts = onnxruntime.SessionOptions()
|
||||||
@@ -32,13 +41,28 @@ class OnnxWrapper():
|
|||||||
else:
|
else:
|
||||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||||
|
|
||||||
self.reset_states()
|
self.path = path
|
||||||
if '16k' in path:
|
if '16k' in path:
|
||||||
warnings.warn('This model support only 16000 sampling rate!')
|
warnings.warn('This model support only 16000 sampling rate!')
|
||||||
self.sample_rates = [16000]
|
self.sample_rates = [16000]
|
||||||
else:
|
else:
|
||||||
self.sample_rates = [8000, 16000]
|
self.sample_rates = [8000, 16000]
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxWrapper():
|
||||||
|
"""
|
||||||
|
ONNX Runtime wrapper for Silero VAD model with per-instance state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
|
||||||
|
self._shared_session = session
|
||||||
|
self.sample_rates = session.sample_rates
|
||||||
|
self.reset_states()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self):
|
||||||
|
return self._shared_session.session
|
||||||
|
|
||||||
def _validate_input(self, x, sr: int):
|
def _validate_input(self, x, sr: int):
|
||||||
if x.dim() == 1:
|
if x.dim() == 1:
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
@@ -91,7 +115,7 @@ class OnnxWrapper():
|
|||||||
out, state = ort_outs
|
out, state = ort_outs
|
||||||
self._state = torch.from_numpy(state)
|
self._state = torch.from_numpy(state)
|
||||||
else:
|
else:
|
||||||
raise ValueError()
|
raise ValueError(f"Unsupported sampling rate {sr}. Supported: {self.sample_rates} (with sample sizes 256 for 8000, 512 for 16000)")
|
||||||
|
|
||||||
self._context = x[..., -context_size:]
|
self._context = x[..., -context_size:]
|
||||||
self._last_sr = sr
|
self._last_sr = sr
|
||||||
@@ -101,37 +125,49 @@ class OnnxWrapper():
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
|
||||||
"""
|
"""Get the path to the ONNX model file."""
|
||||||
Load Silero VAD model (JIT or ONNX).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model_path : str, optional
|
|
||||||
Path to model file. If None, uses default bundled model.
|
|
||||||
onnx : bool, default False
|
|
||||||
Whether to use ONNX runtime (requires onnxruntime package).
|
|
||||||
opset_version : int, default 16
|
|
||||||
ONNX opset version (15 or 16). Only used if onnx=True.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
model
|
|
||||||
Loaded VAD model (JIT or ONNX wrapper)
|
|
||||||
"""
|
|
||||||
available_ops = [15, 16]
|
available_ops = [15, 16]
|
||||||
if onnx and opset_version not in available_ops:
|
if opset_version not in available_ops:
|
||||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}')
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
data_dir = current_dir / 'silero_vad_models'
|
data_dir = current_dir / 'silero_vad_models'
|
||||||
|
|
||||||
if onnx:
|
|
||||||
if opset_version == 16:
|
if opset_version == 16:
|
||||||
model_name = 'silero_vad.onnx'
|
model_name = 'silero_vad.onnx'
|
||||||
else:
|
else:
|
||||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||||
|
|
||||||
|
model_path = data_dir / model_name
|
||||||
|
|
||||||
|
if not model_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Model file not found: {model_path}\n"
|
||||||
|
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
|
model_path = Path(model_path)
|
||||||
|
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
|
||||||
|
"""
|
||||||
|
Load a shared ONNX session for Silero VAD.
|
||||||
|
"""
|
||||||
|
path = _get_onnx_model_path(model_path, opset_version)
|
||||||
|
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
|
||||||
|
|
||||||
|
|
||||||
|
def load_jit_vad(model_path: str = None):
|
||||||
|
"""
|
||||||
|
Load Silero VAD model in JIT format.
|
||||||
|
"""
|
||||||
|
if model_path is None:
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
data_dir = current_dir / 'silero_vad_models'
|
||||||
model_name = 'silero_vad.jit'
|
model_name = 'silero_vad.jit'
|
||||||
|
|
||||||
model_path = data_dir / model_name
|
model_path = data_dir / model_name
|
||||||
@@ -143,15 +179,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
if onnx:
|
|
||||||
try:
|
|
||||||
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
|
||||||
"Or use JIT model by setting onnx=False"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = init_jit_model(str(model_path))
|
model = init_jit_model(str(model_path))
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@@ -227,8 +255,8 @@ class VADIterator:
|
|||||||
if not torch.is_tensor(x):
|
if not torch.is_tensor(x):
|
||||||
try:
|
try:
|
||||||
x = torch.Tensor(x)
|
x = torch.Tensor(x)
|
||||||
except:
|
except (ValueError, TypeError, RuntimeError) as exc:
|
||||||
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
|
raise TypeError("Audio cannot be cast to tensor. Cast it manually") from exc
|
||||||
|
|
||||||
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
||||||
self.current_sample += window_size_samples
|
self.current_sample += window_size_samples
|
||||||
@@ -285,8 +313,8 @@ class FixedVADIterator(VADIterator):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = load_silero_vad(onnx=False)
|
# vad = FixedVADIterator(load_jit_vad())
|
||||||
vad = FixedVADIterator(model)
|
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
|
||||||
|
|
||||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||||
result = vad(audio_buffer)
|
result = vad(audio_buffer)
|
||||||
@@ -295,3 +323,4 @@ if __name__ == "__main__":
|
|||||||
# test with 511 samples
|
# test with 511 samples
|
||||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||||
result = vad(audio_buffer)
|
result = vad(audio_buffer)
|
||||||
|
print(f" 511 samples: {result}")
|
||||||
552
whisperlivekit/simul_whisper/align_att_base.py
Normal file
552
whisperlivekit/simul_whisper/align_att_base.py
Normal file
@@ -0,0 +1,552 @@
|
|||||||
|
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||||
|
|
||||||
|
from .config import AlignAttConfig
|
||||||
|
|
||||||
|
DEC_PAD = 50257
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AlignAttBase(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for AlignAtt streaming decoders.
|
||||||
|
|
||||||
|
Provides shared logic for both PyTorch and MLX implementations:
|
||||||
|
- Properties (speaker, global_time_offset)
|
||||||
|
- Pure-Python methods (warmup, trim_context, refresh_segment, etc.)
|
||||||
|
- Template infer() with abstract hooks for tensor-specific operations
|
||||||
|
- Post-decode logic (token splitting, timestamped word building)
|
||||||
|
|
||||||
|
Subclasses must implement ~20 abstract methods for tensor-specific ops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# === Properties ===
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speaker(self):
|
||||||
|
return self.state.speaker
|
||||||
|
|
||||||
|
@speaker.setter
|
||||||
|
def speaker(self, value):
|
||||||
|
self.state.speaker = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_time_offset(self):
|
||||||
|
return self.state.global_time_offset
|
||||||
|
|
||||||
|
@global_time_offset.setter
|
||||||
|
def global_time_offset(self, value):
|
||||||
|
self.state.global_time_offset = value
|
||||||
|
|
||||||
|
# === Constructor helpers ===
|
||||||
|
|
||||||
|
def _base_init(self, cfg: AlignAttConfig, model):
|
||||||
|
"""Common initialization — call from subclass __init__."""
|
||||||
|
self.model = model
|
||||||
|
self.cfg = cfg
|
||||||
|
self.decode_options = DecodingOptions(
|
||||||
|
language=cfg.language,
|
||||||
|
without_timestamps=True,
|
||||||
|
task=cfg.task,
|
||||||
|
)
|
||||||
|
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||||
|
self.max_text_len = model.dims.n_text_ctx
|
||||||
|
self.num_decoder_layers = len(model.decoder.blocks)
|
||||||
|
if cfg.max_context_tokens is None:
|
||||||
|
self.max_context_tokens = self.max_text_len
|
||||||
|
else:
|
||||||
|
self.max_context_tokens = cfg.max_context_tokens
|
||||||
|
|
||||||
|
def _init_state_common(self, cfg: AlignAttConfig):
|
||||||
|
"""Common state initialization — call from subclass _init_state."""
|
||||||
|
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||||
|
self.state.tokenizer = self.tokenizer
|
||||||
|
self.state.detected_language = cfg.language if cfg.language != "auto" else None
|
||||||
|
self.state.global_time_offset = 0.0
|
||||||
|
self.state.last_attend_frame = -cfg.rewind_threshold
|
||||||
|
self.state.speaker = -1
|
||||||
|
|
||||||
|
# === Shared concrete methods ===
|
||||||
|
|
||||||
|
def warmup(self, audio):
|
||||||
|
try:
|
||||||
|
self.insert_audio(audio)
|
||||||
|
self.infer(is_last=True)
|
||||||
|
self.refresh_segment(complete=True)
|
||||||
|
logger.info("Model warmed up successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Model warmup failed: {e}")
|
||||||
|
|
||||||
|
def create_tokenizer(self, language=None):
|
||||||
|
self.tokenizer = tokenizer.get_tokenizer(
|
||||||
|
multilingual=self.tokenizer_is_multilingual,
|
||||||
|
language=language,
|
||||||
|
num_languages=self.model.num_languages,
|
||||||
|
task=self.decode_options.task,
|
||||||
|
)
|
||||||
|
self.state.tokenizer = self.tokenizer
|
||||||
|
|
||||||
|
def trim_context(self):
|
||||||
|
logger.info("Trimming context")
|
||||||
|
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
|
||||||
|
logger.info(f"Context text: {self.state.context.as_text()}")
|
||||||
|
l = sum(t.shape[1] for t in self.state.tokens) + c
|
||||||
|
after = 0 if self.cfg.static_init_prompt is None else len(self.cfg.static_init_prompt)
|
||||||
|
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
||||||
|
t = self.state.context.trim_words(after=after)
|
||||||
|
l -= t
|
||||||
|
c -= t
|
||||||
|
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||||
|
if t == 0:
|
||||||
|
break
|
||||||
|
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
|
||||||
|
|
||||||
|
def refresh_segment(self, complete=False):
|
||||||
|
logger.debug("Refreshing segment:")
|
||||||
|
self.init_tokens()
|
||||||
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
self.state.cumulative_time_offset = 0.0
|
||||||
|
self.init_context()
|
||||||
|
logger.debug(f"Context: {self.state.context}")
|
||||||
|
if not complete and len(self.state.segments) > 2:
|
||||||
|
self.state.segments = self.state.segments[-2:]
|
||||||
|
else:
|
||||||
|
logger.debug("removing all segments.")
|
||||||
|
self.state.segments = []
|
||||||
|
self.state.log_segments += 1
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
|
||||||
|
def segments_len(self):
|
||||||
|
return sum(s.shape[0] for s in self.state.segments) / 16000
|
||||||
|
|
||||||
|
def _apply_minseglen(self):
|
||||||
|
segments_len = self.segments_len()
|
||||||
|
if segments_len < self.cfg.audio_min_len:
|
||||||
|
logger.debug("waiting for next segment")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _clean_cache(self):
|
||||||
|
self.state.clean_cache()
|
||||||
|
|
||||||
|
def debug_print_tokens(self, tokens):
|
||||||
|
for i in range(min(self.cfg.beam_size, tokens.shape[0])):
|
||||||
|
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
|
||||||
|
|
||||||
|
# === Language detection ===
|
||||||
|
|
||||||
|
def _detect_language_if_needed(self, encoder_feature):
|
||||||
|
if (
|
||||||
|
self.cfg.language == "auto"
|
||||||
|
and self.state.detected_language is None
|
||||||
|
and self.state.first_timestamp
|
||||||
|
):
|
||||||
|
seconds_since_start = self.segments_len() - self.state.first_timestamp
|
||||||
|
if seconds_since_start >= 2.0:
|
||||||
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
|
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
|
self.create_tokenizer(top_lan)
|
||||||
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
self.state.cumulative_time_offset = 0.0
|
||||||
|
self.init_tokens()
|
||||||
|
self.init_context()
|
||||||
|
self.state.detected_language = top_lan
|
||||||
|
logger.info(f"Tokenizer language: {self.tokenizer.language}")
|
||||||
|
|
||||||
|
# === Template infer() ===
|
||||||
|
|
||||||
|
def infer(self, is_last=False):
|
||||||
|
"""Main inference — template method calling abstract hooks for tensor ops."""
|
||||||
|
new_segment = True
|
||||||
|
|
||||||
|
if len(self.state.segments) == 0:
|
||||||
|
logger.debug("No segments, nothing to do")
|
||||||
|
return []
|
||||||
|
if not self._apply_minseglen():
|
||||||
|
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
input_segments = self._concat_segments()
|
||||||
|
encoder_feature, content_mel_len = self._encode(input_segments)
|
||||||
|
self._evaluate(encoder_feature)
|
||||||
|
|
||||||
|
self._detect_language_if_needed(encoder_feature)
|
||||||
|
self.trim_context()
|
||||||
|
current_tokens = self._current_tokens()
|
||||||
|
|
||||||
|
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||||
|
|
||||||
|
sum_logprobs = self._init_sum_logprobs()
|
||||||
|
completed = False
|
||||||
|
token_len_before = current_tokens.shape[1]
|
||||||
|
l_absolute_timestamps = []
|
||||||
|
accumulated_cross_attns = []
|
||||||
|
|
||||||
|
audio_duration_s = self.segments_len()
|
||||||
|
max_tokens = max(50, int(audio_duration_s * 15 * 1.5))
|
||||||
|
tokens_produced = 0
|
||||||
|
most_attended_frame = None
|
||||||
|
|
||||||
|
while not completed and current_tokens.shape[1] < self.max_text_len:
|
||||||
|
tokens_produced += 1
|
||||||
|
if tokens_produced > max_tokens:
|
||||||
|
logger.warning(
|
||||||
|
f"[Loop Detection] Too many tokens ({tokens_produced}) "
|
||||||
|
f"for {audio_duration_s:.2f}s audio. Breaking."
|
||||||
|
)
|
||||||
|
current_tokens = current_tokens[:, :token_len_before]
|
||||||
|
break
|
||||||
|
|
||||||
|
tokens_for_logits = current_tokens if new_segment else current_tokens[:, -1:]
|
||||||
|
logits, cross_attns = self._get_logits_and_cross_attn(
|
||||||
|
tokens_for_logits, encoder_feature
|
||||||
|
)
|
||||||
|
self._evaluate(logits)
|
||||||
|
|
||||||
|
accumulated_cross_attns.append(cross_attns)
|
||||||
|
if len(accumulated_cross_attns) > 16:
|
||||||
|
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||||
|
|
||||||
|
if new_segment and self._check_no_speech(logits):
|
||||||
|
break
|
||||||
|
|
||||||
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
|
if new_segment:
|
||||||
|
logits = self._suppress_blank_tokens(logits)
|
||||||
|
new_segment = False
|
||||||
|
|
||||||
|
logits = self._apply_token_suppression(logits)
|
||||||
|
logits = self._apply_dry_penalty(logits, current_tokens)
|
||||||
|
current_tokens, completed = self._update_tokens(
|
||||||
|
current_tokens, logits, sum_logprobs
|
||||||
|
)
|
||||||
|
self._evaluate(current_tokens)
|
||||||
|
|
||||||
|
logger.debug(f"Decoding completed: {completed}")
|
||||||
|
self.debug_print_tokens(current_tokens)
|
||||||
|
|
||||||
|
attn = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
|
||||||
|
frames_list, most_attended_frame = self._get_attended_frames(attn)
|
||||||
|
|
||||||
|
absolute_timestamps = [
|
||||||
|
(frame * 0.02 + self.state.cumulative_time_offset)
|
||||||
|
for frame in frames_list
|
||||||
|
]
|
||||||
|
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||||
|
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
|
||||||
|
|
||||||
|
if completed:
|
||||||
|
current_tokens = current_tokens[:, :-1]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Rewind check
|
||||||
|
if (
|
||||||
|
not is_last
|
||||||
|
and self.state.last_attend_frame - most_attended_frame
|
||||||
|
> self.cfg.rewind_threshold
|
||||||
|
):
|
||||||
|
if current_tokens.shape[1] > 1 and self._is_special_token(current_tokens):
|
||||||
|
logger.debug("omit rewinding from special tokens")
|
||||||
|
self.state.last_attend_frame = most_attended_frame
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"[rewind detected] current: {most_attended_frame}, "
|
||||||
|
f"last: {self.state.last_attend_frame}"
|
||||||
|
)
|
||||||
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
current_tokens = self._rewind_tokens()
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.state.last_attend_frame = most_attended_frame
|
||||||
|
|
||||||
|
if content_mel_len - most_attended_frame <= (
|
||||||
|
4 if is_last else self.cfg.frame_threshold
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"attention reaches the end: {most_attended_frame}/{content_mel_len}"
|
||||||
|
)
|
||||||
|
current_tokens = current_tokens[:, :-1]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Post-decode: split tokens and build timestamped words
|
||||||
|
tokens_to_split = self._tokens_to_list(current_tokens, token_len_before)
|
||||||
|
if self.state.pending_incomplete_tokens:
|
||||||
|
logger.debug(
|
||||||
|
f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} "
|
||||||
|
f"pending tokens: {self.state.pending_incomplete_tokens}"
|
||||||
|
)
|
||||||
|
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
|
||||||
|
|
||||||
|
new_hypothesis, split_words, split_tokens = self._split_tokens(
|
||||||
|
tokens_to_split, fire_detected, is_last
|
||||||
|
)
|
||||||
|
|
||||||
|
new_tokens_tensor = self._make_new_tokens_tensor(new_hypothesis)
|
||||||
|
self.state.tokens.append(new_tokens_tensor)
|
||||||
|
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||||
|
|
||||||
|
self._clean_cache()
|
||||||
|
|
||||||
|
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
|
||||||
|
self.state.first_timestamp = l_absolute_timestamps[0]
|
||||||
|
|
||||||
|
timestamped_words = self._build_timestamped_words(
|
||||||
|
split_words, split_tokens, l_absolute_timestamps
|
||||||
|
)
|
||||||
|
self._handle_pending_tokens(split_words, split_tokens)
|
||||||
|
|
||||||
|
return timestamped_words
|
||||||
|
|
||||||
|
# === Post-decode shared helpers ===
|
||||||
|
|
||||||
|
def _split_tokens(self, tokens_list, fire_detected, is_last):
|
||||||
|
"""Split token list into words. Returns (hypothesis, split_words, split_tokens)."""
|
||||||
|
if fire_detected or is_last:
|
||||||
|
new_hypothesis = tokens_list
|
||||||
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||||
|
else:
|
||||||
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_list)
|
||||||
|
if len(split_words) > 1:
|
||||||
|
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||||
|
else:
|
||||||
|
new_hypothesis = []
|
||||||
|
return new_hypothesis, split_words, split_tokens
|
||||||
|
|
||||||
|
def _build_timestamped_words(self, split_words, split_tokens, l_absolute_timestamps):
|
||||||
|
"""Build list of timestamped ASRToken from split words."""
|
||||||
|
timestamped_words = []
|
||||||
|
timestamp_idx = 0
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
for word, word_tokens in zip(split_words, split_tokens):
|
||||||
|
if replacement_char in word:
|
||||||
|
cleaned = word.replace(replacement_char, "")
|
||||||
|
if not cleaned.strip():
|
||||||
|
logger.debug(f"[UTF-8 Filter] Skipping: {repr(word)}")
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
continue
|
||||||
|
logger.debug(f"[UTF-8 Filter] Cleaned {repr(word)} -> {repr(cleaned)}")
|
||||||
|
word = cleaned
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
|
except IndexError:
|
||||||
|
logger.warning(
|
||||||
|
f"Timestamp index {timestamp_idx} out of range, using last timestamp"
|
||||||
|
)
|
||||||
|
current_timestamp = (
|
||||||
|
l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
|
||||||
|
)
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
|
||||||
|
timestamp_entry = ASRToken(
|
||||||
|
start=round(current_timestamp, 2),
|
||||||
|
end=round(current_timestamp + 0.1, 2),
|
||||||
|
text=word,
|
||||||
|
speaker=self.state.speaker,
|
||||||
|
detected_language=self.state.detected_language,
|
||||||
|
).with_offset(self.state.global_time_offset)
|
||||||
|
timestamped_words.append(timestamp_entry)
|
||||||
|
|
||||||
|
return timestamped_words
|
||||||
|
|
||||||
|
def _handle_pending_tokens(self, split_words, split_tokens):
|
||||||
|
"""Handle incomplete UTF-8 tokens for next chunk."""
|
||||||
|
MAX_PENDING_TOKENS = 10
|
||||||
|
MAX_PENDING_RETRIES = 2
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
if split_words and replacement_char in split_words[-1]:
|
||||||
|
self.state.pending_retries += 1
|
||||||
|
if self.state.pending_retries > MAX_PENDING_RETRIES:
|
||||||
|
logger.warning(
|
||||||
|
f"[UTF-8 Fix] Dropping {len(split_tokens[-1])} incomplete tokens "
|
||||||
|
f"after {MAX_PENDING_RETRIES} retries (won't resolve)"
|
||||||
|
)
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
elif len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
|
||||||
|
self.state.pending_incomplete_tokens = split_tokens[-1]
|
||||||
|
logger.debug(
|
||||||
|
f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} "
|
||||||
|
f"incomplete tokens for next chunk (retry {self.state.pending_retries})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens "
|
||||||
|
f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)"
|
||||||
|
)
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
else:
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
self.state.pending_retries = 0
|
||||||
|
|
||||||
|
# === Repetition penalty ===
|
||||||
|
|
||||||
|
def _apply_dry_penalty(self, logits, current_tokens):
|
||||||
|
"""DRY penalty v0: penalize tokens that would extend a verbatim repetition.
|
||||||
|
See https://github.com/oobabooga/text-generation-webui/pull/5677
|
||||||
|
|
||||||
|
Scans the decoded sequence for positions where the current suffix already
|
||||||
|
appeared --> for each such match, the token that followed it in the past is
|
||||||
|
penalised exponentially with the match length
|
||||||
|
"""
|
||||||
|
eot = self.tokenizer.eot
|
||||||
|
seq = current_tokens[0].tolist()
|
||||||
|
if len(seq) < 5:
|
||||||
|
return logits
|
||||||
|
|
||||||
|
last = seq[-1]
|
||||||
|
if last >= eot:
|
||||||
|
return logits
|
||||||
|
|
||||||
|
penalties = {}
|
||||||
|
for i in range(len(seq) - 2, -1, -1):
|
||||||
|
if seq[i] != last:
|
||||||
|
continue
|
||||||
|
next_tok = seq[i + 1]
|
||||||
|
if next_tok >= eot:
|
||||||
|
continue
|
||||||
|
|
||||||
|
length = 1
|
||||||
|
while length < 50:
|
||||||
|
j, k = i - length, len(seq) - 1 - length
|
||||||
|
if j < 0 or k <= i:
|
||||||
|
break
|
||||||
|
if seq[j] != seq[k] or seq[j] >= eot:
|
||||||
|
break
|
||||||
|
length += 1
|
||||||
|
|
||||||
|
if next_tok not in penalties or length > penalties[next_tok]:
|
||||||
|
penalties[next_tok] = length
|
||||||
|
|
||||||
|
if penalties:
|
||||||
|
max_len = max(penalties.values())
|
||||||
|
if max_len >= 4:
|
||||||
|
logger.debug(f"[DRY] penalising {len(penalties)} tokens (longest match: {max_len})")
|
||||||
|
for tok, length in penalties.items():
|
||||||
|
if length >= 2:
|
||||||
|
logits[:, tok] = logits[:, tok] - 1.0 * 2.0 ** (length - 2)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
# === Abstract methods — subclass must implement ===
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_state(self, cfg: AlignAttConfig):
|
||||||
|
"""Initialize per-session decoder state."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init_tokens(self):
|
||||||
|
"""Initialize token sequence with framework-specific tensors."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init_context(self):
|
||||||
|
"""Initialize context buffer with framework-specific TokenBuffer."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def insert_audio(self, segment=None):
|
||||||
|
"""Insert audio segment into buffer."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _current_tokens(self):
|
||||||
|
"""Build current token tensor for decoding."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fire_at_boundary(self, feature):
|
||||||
|
"""Check if we should fire at word boundary."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def lang_id(self, encoder_features):
|
||||||
|
"""Language detection from encoder features. Returns (tokens, probs)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _concat_segments(self):
|
||||||
|
"""Concatenate audio segments into single array/tensor."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _encode(self, input_segments):
|
||||||
|
"""Encode audio. Returns (encoder_feature, content_mel_len)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_sum_logprobs(self):
|
||||||
|
"""Create zero sum_logprobs tensor for beam search."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
|
||||||
|
"""Get logits and cross-attention from decoder. Returns (logits, cross_attns)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _check_no_speech(self, logits):
|
||||||
|
"""Check no_speech probability at start of segment. Returns True to break."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _suppress_blank_tokens(self, logits):
|
||||||
|
"""Suppress blank/EOT tokens at segment start. Returns modified logits."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _apply_token_suppression(self, logits):
|
||||||
|
"""Apply general token suppression. Returns modified logits."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _update_tokens(self, current_tokens, logits, sum_logprobs):
|
||||||
|
"""Update tokens via decoder. Returns (current_tokens, completed)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_cross_attention(self, accumulated_cross_attns, content_mel_len):
|
||||||
|
"""Process cross-attention for alignment. Returns attention tensor."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_attended_frames(self, attn):
|
||||||
|
"""Get most attended frames. Returns (frames_as_python_list, first_frame_int)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _is_special_token(self, current_tokens):
|
||||||
|
"""Check if second-to-last token is a special token (>= DEC_PAD)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _rewind_tokens(self):
|
||||||
|
"""Concatenate state tokens for rewind. Returns token tensor."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _tokens_to_list(self, current_tokens, start_col):
|
||||||
|
"""Extract tokens as Python list from start_col onwards."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _make_new_tokens_tensor(self, hypothesis):
|
||||||
|
"""Create tensor from hypothesis token list, repeated for beam search."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _evaluate(self, tensor):
|
||||||
|
"""Evaluate lazy tensor (mx.eval for MLX, no-op for PyTorch)."""
|
||||||
|
...
|
||||||
@@ -11,7 +11,7 @@ import torch
|
|||||||
|
|
||||||
from whisperlivekit.backend_support import (faster_backend_available,
|
from whisperlivekit.backend_support import (faster_backend_available,
|
||||||
mlx_backend_available)
|
mlx_backend_available)
|
||||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||||
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||||
@@ -24,9 +24,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
|
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
||||||
|
from .mlx import MLXAlignAtt
|
||||||
else:
|
else:
|
||||||
mlx_model_mapping = {}
|
mlx_model_mapping = {}
|
||||||
|
MLXAlignAtt = None
|
||||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||||
if HAS_FASTER_WHISPER:
|
if HAS_FASTER_WHISPER:
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
@@ -36,27 +38,26 @@ else:
|
|||||||
MIN_DURATION_REAL_SILENCE = 5
|
MIN_DURATION_REAL_SILENCE = 5
|
||||||
|
|
||||||
class SimulStreamingOnlineProcessor:
|
class SimulStreamingOnlineProcessor:
|
||||||
|
"""Online processor for SimulStreaming ASR."""
|
||||||
SAMPLING_RATE = 16000
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, asr, logfile=sys.stderr):
|
||||||
self,
|
|
||||||
asr,
|
|
||||||
logfile=sys.stderr,
|
|
||||||
):
|
|
||||||
self.asr = asr
|
self.asr = asr
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.end = 0.0
|
self.end = 0.0
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.committed: List[ASRToken] = []
|
self.model = self._create_alignatt()
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
|
||||||
self.load_new_alignatt_instance()
|
|
||||||
|
|
||||||
if asr.tokenizer:
|
if asr.tokenizer:
|
||||||
self.model.tokenizer = asr.tokenizer
|
self.model.tokenizer = asr.tokenizer
|
||||||
|
self.model.state.tokenizer = asr.tokenizer
|
||||||
|
|
||||||
def load_new_alignatt_instance(self):
|
def _create_alignatt(self):
|
||||||
"""Initialize AlignAtt decoder using the shared model."""
|
"""Create the AlignAtt decoder instance based on ASR mode."""
|
||||||
self.model = AlignAtt(
|
if self.asr.use_full_mlx and HAS_MLX_WHISPER:
|
||||||
|
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
|
||||||
|
else:
|
||||||
|
return AlignAtt(
|
||||||
cfg=self.asr.cfg,
|
cfg=self.asr.cfg,
|
||||||
loaded_model=self.asr.shared_model,
|
loaded_model=self.asr.shared_model,
|
||||||
mlx_encoder=self.asr.mlx_encoder,
|
mlx_encoder=self.asr.mlx_encoder,
|
||||||
@@ -68,17 +69,15 @@ class SimulStreamingOnlineProcessor:
|
|||||||
return tokens, processed_upto
|
return tokens, processed_upto
|
||||||
|
|
||||||
def end_silence(self, silence_duration, offset):
|
def end_silence(self, silence_duration, offset):
|
||||||
"""
|
"""Handle silence period."""
|
||||||
Handle silence period.
|
|
||||||
|
|
||||||
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
|
|
||||||
Otherwise, insert a small silence and shift the last_attend_frame.
|
|
||||||
"""
|
|
||||||
self.end += silence_duration
|
self.end += silence_duration
|
||||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||||
if not long_silence:
|
if not long_silence:
|
||||||
gap_len = int(16000 * silence_duration)
|
gap_len = int(16000 * silence_duration)
|
||||||
if gap_len > 0:
|
if gap_len > 0:
|
||||||
|
if self.asr.use_full_mlx:
|
||||||
|
gap_silence = np.zeros(gap_len, dtype=np.float32)
|
||||||
|
else:
|
||||||
gap_silence = torch.zeros(gap_len)
|
gap_silence = torch.zeros(gap_len)
|
||||||
self.model.insert_audio(gap_silence)
|
self.model.insert_audio(gap_silence)
|
||||||
if long_silence:
|
if long_silence:
|
||||||
@@ -87,10 +86,11 @@ class SimulStreamingOnlineProcessor:
|
|||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
|
self.end = audio_stream_end_time
|
||||||
# Convert numpy array to torch tensor
|
if self.asr.use_full_mlx:
|
||||||
|
self.model.insert_audio(audio)
|
||||||
|
else:
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
|
|
||||||
self.model.insert_audio(audio_tensor)
|
self.model.insert_audio(audio_tensor)
|
||||||
|
|
||||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||||
@@ -120,7 +120,6 @@ class SimulStreamingOnlineProcessor:
|
|||||||
self.buffer.extend(timestamped_words)
|
self.buffer.extend(timestamped_words)
|
||||||
return [], self.end
|
return [], self.end
|
||||||
|
|
||||||
self.committed.extend(timestamped_words)
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
return timestamped_words, self.end
|
return timestamped_words, self.end
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -130,6 +129,10 @@ class SimulStreamingOnlineProcessor:
|
|||||||
def warmup(self, audio, init_prompt=""):
|
def warmup(self, audio, init_prompt=""):
|
||||||
"""Warmup the SimulStreaming model."""
|
"""Warmup the SimulStreaming model."""
|
||||||
try:
|
try:
|
||||||
|
if self.asr.use_full_mlx:
|
||||||
|
# MLX mode: ensure numpy array
|
||||||
|
if hasattr(audio, 'numpy'):
|
||||||
|
audio = audio.numpy()
|
||||||
self.model.insert_audio(audio)
|
self.model.insert_audio(audio)
|
||||||
self.model.infer(True)
|
self.model.infer(True)
|
||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
@@ -139,9 +142,14 @@ class SimulStreamingOnlineProcessor:
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
|
||||||
|
try:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
class SimulStreamingASR():
|
|
||||||
|
class SimulStreamingASR:
|
||||||
"""SimulStreaming backend with AlignAtt policy."""
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
sep = ""
|
sep = ""
|
||||||
|
|
||||||
@@ -158,35 +166,25 @@ class SimulStreamingASR():
|
|||||||
self.fast_encoder = False
|
self.fast_encoder = False
|
||||||
self._resolved_model_path = None
|
self._resolved_model_path = None
|
||||||
self.encoder_backend = "whisper"
|
self.encoder_backend = "whisper"
|
||||||
|
self.use_full_mlx = getattr(self, "use_full_mlx", False)
|
||||||
preferred_backend = getattr(self, "backend", "auto")
|
preferred_backend = getattr(self, "backend", "auto")
|
||||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
||||||
|
|
||||||
if self.model_path:
|
if self.model_path:
|
||||||
resolved_model_path = resolve_model_path(self.model_path)
|
resolved_model_path = resolve_model_path(self.model_path)
|
||||||
self._resolved_model_path = resolved_model_path
|
self._resolved_model_path = resolved_model_path
|
||||||
self.model_path = str(resolved_model_path)
|
self.model_path = str(resolved_model_path)
|
||||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
|
||||||
if self.pytorch_path:
|
model_info = detect_model_format(resolved_model_path)
|
||||||
self.model_name = self.pytorch_path.stem
|
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
||||||
else:
|
compatible_faster_whisper = model_info.compatible_faster_whisper
|
||||||
self.model_name = Path(self.model_path).stem
|
|
||||||
|
if not self.use_full_mlx and not model_info.has_pytorch:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||||
)
|
)
|
||||||
|
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
|
||||||
elif self.model_size is not None:
|
elif self.model_size is not None:
|
||||||
model_mapping = {
|
|
||||||
'tiny': './tiny.pt',
|
|
||||||
'base': './base.pt',
|
|
||||||
'small': './small.pt',
|
|
||||||
'medium': './medium.pt',
|
|
||||||
'medium.en': './medium.en.pt',
|
|
||||||
'large-v1': './large-v1.pt',
|
|
||||||
'base.en': './base.en.pt',
|
|
||||||
'small.en': './small.en.pt',
|
|
||||||
'tiny.en': './tiny.en.pt',
|
|
||||||
'large-v2': './large-v2.pt',
|
|
||||||
'large-v3': './large-v3.pt',
|
|
||||||
'large': './large-v3.pt'
|
|
||||||
}
|
|
||||||
self.model_name = self.model_size
|
self.model_name = self.model_size
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||||
@@ -202,6 +200,13 @@ class SimulStreamingASR():
|
|||||||
if self.encoder_backend == "whisper":
|
if self.encoder_backend == "whisper":
|
||||||
self.disable_fast_encoder = True
|
self.disable_fast_encoder = True
|
||||||
|
|
||||||
|
# MLX full decoder disabled by default — MLXAlignAtt has known issues
|
||||||
|
# with token generation after punctuation. Users can opt-in with
|
||||||
|
# --use-full-mlx if they want to test it.
|
||||||
|
# if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
|
||||||
|
# if not hasattr(self, '_full_mlx_disabled'):
|
||||||
|
# self.use_full_mlx = True
|
||||||
|
|
||||||
self.cfg = AlignAttConfig(
|
self.cfg = AlignAttConfig(
|
||||||
tokenizer_is_multilingual= is_multilingual,
|
tokenizer_is_multilingual= is_multilingual,
|
||||||
segment_length=self.min_chunk_size,
|
segment_length=self.min_chunk_size,
|
||||||
@@ -212,7 +217,7 @@ class SimulStreamingASR():
|
|||||||
cif_ckpt_path=self.cif_ckpt_path,
|
cif_ckpt_path=self.cif_ckpt_path,
|
||||||
decoder_type="beam",
|
decoder_type="beam",
|
||||||
beam_size=self.beams,
|
beam_size=self.beams,
|
||||||
task=self.direct_english_translation,
|
task="translate" if self.direct_english_translation else "transcribe",
|
||||||
never_fire=self.never_fire,
|
never_fire=self.never_fire,
|
||||||
init_prompt=self.init_prompt,
|
init_prompt=self.init_prompt,
|
||||||
max_context_tokens=self.max_context_tokens,
|
max_context_tokens=self.max_context_tokens,
|
||||||
@@ -225,20 +230,36 @@ class SimulStreamingASR():
|
|||||||
else:
|
else:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
||||||
self.mlx_encoder, self.fw_encoder = None, None
|
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
|
||||||
if self.encoder_backend == "mlx-whisper":
|
self.shared_model = None
|
||||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
|
||||||
|
if self.use_full_mlx and HAS_MLX_WHISPER:
|
||||||
|
logger.info('MLX Whisper backend used.')
|
||||||
if self._resolved_model_path is not None:
|
if self._resolved_model_path is not None:
|
||||||
mlx_model = str(self._resolved_model_path)
|
mlx_model_path = str(self._resolved_model_path)
|
||||||
else:
|
else:
|
||||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||||
if not mlx_model:
|
if not mlx_model_path:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||||
)
|
)
|
||||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
|
||||||
|
self._warmup_mlx_model()
|
||||||
|
elif self.encoder_backend == "mlx-whisper":
|
||||||
|
# hybrid mode: mlx encoder + pytorch decoder
|
||||||
|
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
|
||||||
|
if self._resolved_model_path is not None:
|
||||||
|
mlx_model_path = str(self._resolved_model_path)
|
||||||
|
else:
|
||||||
|
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||||
|
if not mlx_model_path:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||||
|
)
|
||||||
|
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
||||||
|
self.shared_model = self.load_model()
|
||||||
elif self.encoder_backend == "faster-whisper":
|
elif self.encoder_backend == "faster-whisper":
|
||||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
print('SimulStreaming will use Faster Whisper for the encoder.')
|
||||||
if self._resolved_model_path is not None:
|
if self._resolved_model_path is not None:
|
||||||
fw_model = str(self._resolved_model_path)
|
fw_model = str(self._resolved_model_path)
|
||||||
else:
|
else:
|
||||||
@@ -249,6 +270,19 @@ class SimulStreamingASR():
|
|||||||
compute_type='auto',
|
compute_type='auto',
|
||||||
)
|
)
|
||||||
self.shared_model = self.load_model()
|
self.shared_model = self.load_model()
|
||||||
|
else:
|
||||||
|
self.shared_model = self.load_model()
|
||||||
|
|
||||||
|
def _warmup_mlx_model(self):
|
||||||
|
"""Warmup the full MLX model."""
|
||||||
|
warmup_audio = load_file(self.warmup_file)
|
||||||
|
if warmup_audio is not None:
|
||||||
|
temp_model = MLXAlignAtt(
|
||||||
|
cfg=self.cfg,
|
||||||
|
mlx_model=self.mlx_model,
|
||||||
|
)
|
||||||
|
temp_model.warmup(warmup_audio)
|
||||||
|
logger.info("Full MLX model warmed up successfully")
|
||||||
|
|
||||||
|
|
||||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||||
@@ -292,11 +326,14 @@ class SimulStreamingASR():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
|
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
|
||||||
|
lora_path = getattr(self, 'lora_path', None)
|
||||||
whisper_model = load_model(
|
whisper_model = load_model(
|
||||||
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
name=model_ref,
|
||||||
download_root=self.model_path,
|
download_root=getattr(self, 'model_cache_dir', None),
|
||||||
decoder_only=self.fast_encoder,
|
decoder_only=self.fast_encoder,
|
||||||
custom_alignment_heads=self.custom_alignment_heads
|
custom_alignment_heads=self.custom_alignment_heads,
|
||||||
|
lora_path=lora_path,
|
||||||
)
|
)
|
||||||
warmup_audio = load_file(self.warmup_file)
|
warmup_audio = load_file(self.warmup_file)
|
||||||
if warmup_audio is not None:
|
if warmup_audio is not None:
|
||||||
@@ -316,7 +353,7 @@ class SimulStreamingASR():
|
|||||||
def set_translate_task(self):
|
def set_translate_task(self):
|
||||||
"""Set up translation task."""
|
"""Set up translation task."""
|
||||||
if self.cfg.language == 'auto':
|
if self.cfg.language == 'auto':
|
||||||
raise Exception('Translation cannot be done with language = auto')
|
raise ValueError('Translation cannot be done with language = auto')
|
||||||
return tokenizer.get_tokenizer(
|
return tokenizer.get_tokenizer(
|
||||||
multilingual=True,
|
multilingual=True,
|
||||||
language=self.cfg.language,
|
language=self.cfg.language,
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class DecoderState:
|
|||||||
context: Any = None
|
context: Any = None
|
||||||
|
|
||||||
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||||
|
pending_retries: int = 0
|
||||||
|
|
||||||
global_time_offset: float = 0.0
|
global_time_offset: float = 0.0
|
||||||
cumulative_time_offset: float = 0.0
|
cumulative_time_offset: float = 0.0
|
||||||
@@ -47,9 +48,24 @@ class DecoderState:
|
|||||||
|
|
||||||
def clean_cache(self):
|
def clean_cache(self):
|
||||||
"""Clean the kv_cache after each inference step."""
|
"""Clean the kv_cache after each inference step."""
|
||||||
self.kv_cache = {}
|
# Explicitly delete tensor references to free GPU memory
|
||||||
|
if self.kv_cache:
|
||||||
|
for key in list(self.kv_cache.keys()):
|
||||||
|
tensor = self.kv_cache.pop(key, None)
|
||||||
|
if tensor is not None:
|
||||||
|
del tensor
|
||||||
|
|
||||||
|
# Clear the dict
|
||||||
|
self.kv_cache.clear()
|
||||||
|
|
||||||
|
# Force GPU cache cleanup (only if CUDA is available)
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if self.decoder_type == "beam" and self.inference is not None:
|
if self.decoder_type == "beam" and self.inference is not None:
|
||||||
self.inference.kv_cache = self.kv_cache
|
# Create NEW dict instead of sharing reference
|
||||||
|
self.inference.kv_cache = {}
|
||||||
if self.token_decoder is not None:
|
if self.token_decoder is not None:
|
||||||
self.token_decoder.reset()
|
self.token_decoder.reset()
|
||||||
|
|
||||||
@@ -63,6 +79,7 @@ class DecoderState:
|
|||||||
self.last_attend_frame = -rewind_threshold
|
self.last_attend_frame = -rewind_threshold
|
||||||
self.cumulative_time_offset = 0.0
|
self.cumulative_time_offset = 0.0
|
||||||
self.pending_incomplete_tokens = []
|
self.pending_incomplete_tokens = []
|
||||||
|
self.pending_retries = 0
|
||||||
self.log_segments += 1
|
self.log_segments += 1
|
||||||
|
|
||||||
def full_reset(self, rewind_threshold: int = 200):
|
def full_reset(self, rewind_threshold: int = 200):
|
||||||
|
|||||||
11
whisperlivekit/simul_whisper/mlx/__init__.py
Normal file
11
whisperlivekit/simul_whisper/mlx/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from .decoder_state import MLXDecoderState
|
||||||
|
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||||
|
from .simul_whisper import MLXAlignAtt
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MLXAlignAtt",
|
||||||
|
"MLXBeamSearchDecoder",
|
||||||
|
"MLXDecoderState",
|
||||||
|
"MLXGreedyDecoder",
|
||||||
|
"MLXInference",
|
||||||
|
]
|
||||||
78
whisperlivekit/simul_whisper/mlx/decoder_state.py
Normal file
78
whisperlivekit/simul_whisper/mlx/decoder_state.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLXDecoderState:
|
||||||
|
"""
|
||||||
|
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
|
||||||
|
where each element is a tuple of mx.arrays.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
|
||||||
|
|
||||||
|
tokenizer: Any = None
|
||||||
|
detected_language: Optional[str] = None
|
||||||
|
reset_tokenizer_to_auto_next_call: bool = False
|
||||||
|
|
||||||
|
tokens: List[mx.array] = field(default_factory=list)
|
||||||
|
initial_tokens: Optional[mx.array] = None
|
||||||
|
initial_token_length: int = 0
|
||||||
|
sot_index: int = 0
|
||||||
|
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
|
||||||
|
num_align_heads: int = 0
|
||||||
|
segments: List[np.ndarray] = field(default_factory=list)
|
||||||
|
|
||||||
|
context: Any = None
|
||||||
|
|
||||||
|
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||||
|
pending_retries: int = 0
|
||||||
|
|
||||||
|
global_time_offset: float = 0.0
|
||||||
|
cumulative_time_offset: float = 0.0
|
||||||
|
first_timestamp: Optional[float] = None
|
||||||
|
last_attend_frame: int = 0
|
||||||
|
|
||||||
|
speaker: int = -1
|
||||||
|
log_segments: int = 0
|
||||||
|
cif_weights: Optional[mx.array] = None
|
||||||
|
always_fire: bool = False
|
||||||
|
never_fire: bool = False
|
||||||
|
|
||||||
|
suppress_tokens: Optional[Tuple[int, ...]] = None
|
||||||
|
|
||||||
|
token_decoder: Any = None
|
||||||
|
decoder_type: str = "greedy"
|
||||||
|
|
||||||
|
inference: Any = None
|
||||||
|
|
||||||
|
def clean_cache(self):
|
||||||
|
self.kv_cache = None
|
||||||
|
if self.decoder_type == "beam" and self.inference is not None:
|
||||||
|
self.inference.kv_cache = None
|
||||||
|
if self.token_decoder is not None:
|
||||||
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
def reset(self, rewind_threshold: int = 200):
|
||||||
|
self.last_attend_frame = -rewind_threshold
|
||||||
|
self.cumulative_time_offset = 0.0
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
self.pending_retries = 0
|
||||||
|
self.log_segments += 1
|
||||||
|
|
||||||
|
def full_reset(self, rewind_threshold: int = 200):
|
||||||
|
"""
|
||||||
|
Full reset including audio segments and tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewind_threshold: Value for resetting last_attend_frame
|
||||||
|
"""
|
||||||
|
self.reset(rewind_threshold)
|
||||||
|
self.segments = []
|
||||||
|
self.tokens = []
|
||||||
|
self.kv_cache = None
|
||||||
|
self.first_timestamp = None
|
||||||
|
|
||||||
219
whisperlivekit/simul_whisper/mlx/decoders.py
Normal file
219
whisperlivekit/simul_whisper/mlx/decoders.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""
|
||||||
|
MLX-native token decoders for streaming ASR.
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class MLXGreedyDecoder:
|
||||||
|
"""Greedy decoder using MLX operations."""
|
||||||
|
|
||||||
|
def __init__(self, temperature: float, eot: int):
|
||||||
|
self.temperature = temperature
|
||||||
|
self.eot = eot
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||||
|
) -> Tuple[mx.array, bool]:
|
||||||
|
"""
|
||||||
|
Update tokens with next predicted token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: Current token sequence, shape (batch, seq_len)
|
||||||
|
logits: Logits for next token, shape (batch, vocab_size)
|
||||||
|
sum_logprobs: Cumulative log probabilities, shape (batch,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated tokens and completion flag
|
||||||
|
"""
|
||||||
|
if self.temperature == 0:
|
||||||
|
next_tokens = mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
probs = mx.softmax(logits / self.temperature, axis=-1)
|
||||||
|
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
|
||||||
|
|
||||||
|
logprobs = mx.softmax(logits, axis=-1)
|
||||||
|
logprobs = mx.log(logprobs + 1e-10)
|
||||||
|
batch_size = logprobs.shape[0]
|
||||||
|
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
|
||||||
|
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
|
||||||
|
sum_logprobs = sum_logprobs + current_logprobs * mask
|
||||||
|
eot_mask = (tokens[:, -1] == self.eot)
|
||||||
|
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
|
||||||
|
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
|
||||||
|
completed = bool(mx.all(tokens[:, -1] == self.eot))
|
||||||
|
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
||||||
|
"""Finalize decoding by ensuring EOT at end."""
|
||||||
|
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
|
||||||
|
tokens = mx.concatenate([tokens, eot_column], axis=1)
|
||||||
|
return tokens, sum_logprobs.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class MLXBeamSearchDecoder:
|
||||||
|
"""Beam search decoder using MLX operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
beam_size: int,
|
||||||
|
eot: int,
|
||||||
|
inference: Any,
|
||||||
|
patience: Optional[float] = None,
|
||||||
|
):
|
||||||
|
self.beam_size = beam_size
|
||||||
|
self.eot = eot
|
||||||
|
self.inference = inference
|
||||||
|
self.patience = patience or 1.0
|
||||||
|
self.max_candidates: int = round(beam_size * self.patience)
|
||||||
|
self.finished_sequences: Optional[List[Dict]] = None
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.max_candidates > 0
|
||||||
|
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset finished sequences for new segment."""
|
||||||
|
self.finished_sequences = None
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||||
|
) -> Tuple[mx.array, bool]:
|
||||||
|
"""
|
||||||
|
Update tokens using beam search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: Current token sequences, shape (batch * beam_size, seq_len)
|
||||||
|
logits: Logits for next token, shape (batch * beam_size, vocab_size)
|
||||||
|
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated tokens and completion flag
|
||||||
|
"""
|
||||||
|
if tokens.shape[0] % self.beam_size != 0:
|
||||||
|
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||||
|
|
||||||
|
n_audio = tokens.shape[0] // self.beam_size
|
||||||
|
if self.finished_sequences is None:
|
||||||
|
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||||
|
logprobs = mx.softmax(logits, axis=-1)
|
||||||
|
logprobs = mx.log(logprobs + 1e-10)
|
||||||
|
logprobs_np = np.array(logprobs)
|
||||||
|
tokens_np = np.array(tokens)
|
||||||
|
sum_logprobs_np = np.array(sum_logprobs)
|
||||||
|
|
||||||
|
next_tokens, source_indices, finished_sequences = [], [], []
|
||||||
|
new_sum_logprobs = []
|
||||||
|
|
||||||
|
for i in range(n_audio):
|
||||||
|
scores, sources, finished = {}, {}, {}
|
||||||
|
for j in range(self.beam_size):
|
||||||
|
idx = i * self.beam_size + j
|
||||||
|
prefix = tokens_np[idx].tolist()
|
||||||
|
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
|
||||||
|
|
||||||
|
for token_idx in top_k_indices:
|
||||||
|
logprob = logprobs_np[idx, token_idx]
|
||||||
|
new_logprob = sum_logprobs_np[idx] + logprob
|
||||||
|
sequence = tuple(prefix + [int(token_idx)])
|
||||||
|
scores[sequence] = new_logprob
|
||||||
|
sources[sequence] = idx
|
||||||
|
saved = 0
|
||||||
|
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||||
|
if sequence[-1] == self.eot:
|
||||||
|
finished[sequence] = scores[sequence]
|
||||||
|
else:
|
||||||
|
new_sum_logprobs.append(scores[sequence])
|
||||||
|
next_tokens.append(sequence)
|
||||||
|
source_indices.append(sources[sequence])
|
||||||
|
|
||||||
|
saved += 1
|
||||||
|
if saved == self.beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
finished_sequences.append(finished)
|
||||||
|
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
|
||||||
|
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
|
||||||
|
self.inference.rearrange_kv_cache(source_indices)
|
||||||
|
assert len(self.finished_sequences) == len(finished_sequences)
|
||||||
|
for previously_finished, newly_finished in zip(
|
||||||
|
self.finished_sequences, finished_sequences
|
||||||
|
):
|
||||||
|
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||||
|
if len(previously_finished) >= self.max_candidates:
|
||||||
|
break
|
||||||
|
previously_finished[seq] = newly_finished[seq]
|
||||||
|
completed = all(
|
||||||
|
len(sequences) >= self.max_candidates
|
||||||
|
for sequences in self.finished_sequences
|
||||||
|
)
|
||||||
|
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
|
||||||
|
"""Finalize beam search by selecting best sequences."""
|
||||||
|
preceding_tokens_np = np.array(preceding_tokens)
|
||||||
|
sum_logprobs_np = np.array(sum_logprobs)
|
||||||
|
|
||||||
|
n_audio = preceding_tokens_np.shape[0] // self.beam_size
|
||||||
|
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
|
||||||
|
sum_logprobs_list: List[float] = [0.0] * n_audio
|
||||||
|
|
||||||
|
for i, sequences in enumerate(self.finished_sequences):
|
||||||
|
if sequences:
|
||||||
|
best_seq = max(sequences, key=sequences.get)
|
||||||
|
tokens_list[i] = list(best_seq)
|
||||||
|
sum_logprobs_list[i] = sequences[best_seq]
|
||||||
|
else:
|
||||||
|
idx = i * self.beam_size
|
||||||
|
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
|
||||||
|
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
|
||||||
|
max_len = max(len(t) for t in tokens_list)
|
||||||
|
for i, t in enumerate(tokens_list):
|
||||||
|
tokens_list[i] = t + [self.eot] * (max_len - len(t))
|
||||||
|
|
||||||
|
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
|
||||||
|
return tokens, sum_logprobs_list
|
||||||
|
|
||||||
|
|
||||||
|
class MLXInference:
|
||||||
|
"""MLX inference wrapper for beam search KV cache management."""
|
||||||
|
|
||||||
|
def __init__(self, model, initial_token_length: int):
|
||||||
|
self.model = model
|
||||||
|
self.initial_token_length = initial_token_length
|
||||||
|
self.kv_cache = None
|
||||||
|
|
||||||
|
def rearrange_kv_cache(self, source_indices: List[int]):
|
||||||
|
"""Rearrange KV cache based on beam search source indices."""
|
||||||
|
if self.kv_cache is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if source_indices == list(range(len(source_indices))):
|
||||||
|
return
|
||||||
|
|
||||||
|
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
|
||||||
|
|
||||||
|
new_cache = []
|
||||||
|
for layer_cache in self.kv_cache:
|
||||||
|
(k, v), (cross_k, cross_v) = layer_cache
|
||||||
|
new_k = k[source_indices_mx]
|
||||||
|
new_v = v[source_indices_mx]
|
||||||
|
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
|
||||||
|
|
||||||
|
self.kv_cache = new_cache
|
||||||
|
|
||||||
|
def logits(
|
||||||
|
self,
|
||||||
|
tokens: mx.array,
|
||||||
|
audio_features: mx.array,
|
||||||
|
) -> Tuple[mx.array, List]:
|
||||||
|
"""Get logits from decoder with KV cache."""
|
||||||
|
logits, self.kv_cache, cross_qk = self.model.decoder(
|
||||||
|
tokens, audio_features, kv_cache=self.kv_cache
|
||||||
|
)
|
||||||
|
return logits, cross_qk
|
||||||
|
|
||||||
421
whisperlivekit/simul_whisper/mlx/simul_whisper.py
Normal file
421
whisperlivekit/simul_whisper/mlx/simul_whisper.py
Normal file
@@ -0,0 +1,421 @@
|
|||||||
|
"""MLX whisper AlignAtt streaming decoder."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||||
|
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||||
|
|
||||||
|
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
|
||||||
|
|
||||||
|
from ..align_att_base import DEC_PAD, AlignAttBase
|
||||||
|
from ..config import AlignAttConfig
|
||||||
|
from .decoder_state import MLXDecoderState
|
||||||
|
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MLXTokenBuffer:
|
||||||
|
"""Token buffer for MLX-based decoding."""
|
||||||
|
|
||||||
|
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
|
||||||
|
self.text = text
|
||||||
|
self.prefix_token_ids = prefix_token_ids or []
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
|
def as_token_ids(self, tokenizer=None):
|
||||||
|
if tokenizer is None:
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
if tokenizer is None:
|
||||||
|
raise ValueError("Tokenizer is not set.")
|
||||||
|
return self.prefix_token_ids + tokenizer.encode(self.text)
|
||||||
|
|
||||||
|
def as_mlx_array(self) -> mx.array:
|
||||||
|
tok_ids = self.as_token_ids()
|
||||||
|
return mx.array([tok_ids], dtype=mx.int32)
|
||||||
|
|
||||||
|
def as_mlx_array_beam(self, beam: int) -> mx.array:
|
||||||
|
t = self.as_mlx_array()
|
||||||
|
return mx.repeat(t, beam, axis=0)
|
||||||
|
|
||||||
|
def as_text(self):
|
||||||
|
return self.text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def empty(*a, **kw):
|
||||||
|
return MLXTokenBuffer(*a, **kw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_text(text, *a, **kw):
|
||||||
|
return MLXTokenBuffer(*a, text=text, **kw)
|
||||||
|
|
||||||
|
def is_empty(self):
|
||||||
|
return self.text is None or self.text == ""
|
||||||
|
|
||||||
|
def trim_words(self, num=1, after=0):
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
assert tokenizer is not None, "Tokenizer is not set."
|
||||||
|
ids = tokenizer.encode(self.text[after:])
|
||||||
|
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
||||||
|
if not words:
|
||||||
|
return 0
|
||||||
|
self.text = self.text[:after] + "".join(words[num:])
|
||||||
|
return sum(len(wi) for wi in wids[:num])
|
||||||
|
|
||||||
|
def append_token_ids(self, token_ids):
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
assert tokenizer is not None, "Tokenizer is not set."
|
||||||
|
all_tokens = self.pending_token_ids + token_ids
|
||||||
|
decoded = tokenizer.decode(all_tokens)
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
if replacement_char in decoded:
|
||||||
|
if len(all_tokens) > 1:
|
||||||
|
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||||
|
if replacement_char not in decoded_partial:
|
||||||
|
self.text += decoded_partial
|
||||||
|
self.pending_token_ids = [all_tokens[-1]]
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.text += decoded
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
|
||||||
|
"""Apply median filter along the last axis."""
|
||||||
|
if filter_width <= 1:
|
||||||
|
return x
|
||||||
|
pad_width = filter_width // 2
|
||||||
|
shape = x.shape
|
||||||
|
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
|
||||||
|
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
|
||||||
|
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
|
||||||
|
result = []
|
||||||
|
for i in range(shape[-1]):
|
||||||
|
window = x_padded[..., i:i + filter_width]
|
||||||
|
sorted_window = mx.sort(window, axis=-1)
|
||||||
|
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
|
||||||
|
result.append(median_val)
|
||||||
|
return mx.concatenate(result, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class MLXAlignAtt(AlignAttBase):
|
||||||
|
"""
|
||||||
|
MLX-native Alignment-based Attention decoder for SimulStreaming.
|
||||||
|
|
||||||
|
Runs entirely on MLX, with no PyTorch dependencies for inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: AlignAttConfig,
|
||||||
|
mlx_model: Any,
|
||||||
|
) -> None:
|
||||||
|
# Common init (sets self.model, self.cfg, decode_options, etc.)
|
||||||
|
self._base_init(cfg, mlx_model)
|
||||||
|
logger.info(f"MLX Model dimensions: {self.model.dims}")
|
||||||
|
|
||||||
|
# Per-session state
|
||||||
|
self.state = MLXDecoderState()
|
||||||
|
self._init_state(cfg)
|
||||||
|
|
||||||
|
def _init_state(self, cfg: AlignAttConfig):
|
||||||
|
self._init_state_common(cfg)
|
||||||
|
|
||||||
|
# CIF: MLX doesn't support CIF checkpoint loading
|
||||||
|
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
|
||||||
|
if cfg.never_fire:
|
||||||
|
self.state.never_fire = True
|
||||||
|
self.state.always_fire = False
|
||||||
|
else:
|
||||||
|
self.state.always_fire = True
|
||||||
|
self.state.never_fire = False
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"CIF checkpoint provided but MLX CIF not implemented. "
|
||||||
|
"Using always_fire=True"
|
||||||
|
)
|
||||||
|
self.state.always_fire = True
|
||||||
|
self.state.never_fire = cfg.never_fire
|
||||||
|
|
||||||
|
self._build_alignment_source()
|
||||||
|
|
||||||
|
# Suppress tokens
|
||||||
|
suppress_tokens = [
|
||||||
|
self.tokenizer.transcribe, self.tokenizer.translate,
|
||||||
|
self.tokenizer.sot, self.tokenizer.sot_prev,
|
||||||
|
self.tokenizer.sot_lm, self.tokenizer.no_timestamps,
|
||||||
|
] + list(self.tokenizer.all_language_tokens)
|
||||||
|
if self.tokenizer.no_speech is not None:
|
||||||
|
suppress_tokens.append(self.tokenizer.no_speech)
|
||||||
|
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||||
|
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
|
||||||
|
|
||||||
|
self.init_tokens()
|
||||||
|
self.init_context()
|
||||||
|
|
||||||
|
# Decoder type
|
||||||
|
self.state.decoder_type = cfg.decoder_type
|
||||||
|
if cfg.decoder_type == "greedy":
|
||||||
|
logger.info("Using MLX greedy decoder")
|
||||||
|
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
|
||||||
|
elif cfg.decoder_type == "beam":
|
||||||
|
logger.info("Using MLX beam decoder")
|
||||||
|
self.state.inference = MLXInference(
|
||||||
|
self.model, self.state.initial_token_length,
|
||||||
|
)
|
||||||
|
self.state.token_decoder = MLXBeamSearchDecoder(
|
||||||
|
inference=self.state.inference,
|
||||||
|
eot=self.tokenizer.eot,
|
||||||
|
beam_size=cfg.beam_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_alignment_source(self):
|
||||||
|
"""Build alignment source mapping from model's alignment_heads."""
|
||||||
|
self.state.align_source = {}
|
||||||
|
self.state.num_align_heads = 0
|
||||||
|
alignment_heads = self.model.alignment_heads
|
||||||
|
if alignment_heads is None:
|
||||||
|
logger.warning("No alignment heads found in model")
|
||||||
|
return
|
||||||
|
if hasattr(alignment_heads, 'tolist'):
|
||||||
|
heads_list = alignment_heads.tolist()
|
||||||
|
else:
|
||||||
|
heads_list = np.array(alignment_heads).tolist()
|
||||||
|
for layer_rank, head_id in heads_list:
|
||||||
|
layer_rank = int(layer_rank)
|
||||||
|
head_id = int(head_id)
|
||||||
|
heads = self.state.align_source.get(layer_rank, [])
|
||||||
|
heads.append((self.state.num_align_heads, head_id))
|
||||||
|
self.state.align_source[layer_rank] = heads
|
||||||
|
self.state.num_align_heads += 1
|
||||||
|
|
||||||
|
# === Abstract method implementations ===
|
||||||
|
|
||||||
|
def init_tokens(self):
|
||||||
|
logger.debug(f"init tokens, {len(self.state.segments)}")
|
||||||
|
self.state.initial_tokens = mx.array(
|
||||||
|
[self.tokenizer.sot_sequence_including_notimestamps],
|
||||||
|
dtype=mx.int32,
|
||||||
|
)
|
||||||
|
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
||||||
|
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||||
|
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
||||||
|
self.state.tokens = [self.state.initial_tokens]
|
||||||
|
|
||||||
|
def init_context(self):
|
||||||
|
kw = {
|
||||||
|
'tokenizer': self.tokenizer,
|
||||||
|
'prefix_token_ids': [self.tokenizer.sot_prev],
|
||||||
|
}
|
||||||
|
self.state.context = MLXTokenBuffer.empty(**kw)
|
||||||
|
if self.cfg.static_init_prompt is not None:
|
||||||
|
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||||
|
if self.cfg.init_prompt is not None:
|
||||||
|
self.state.context.text += self.cfg.init_prompt
|
||||||
|
|
||||||
|
def insert_audio(self, segment=None):
|
||||||
|
if segment is not None:
|
||||||
|
if hasattr(segment, 'numpy'):
|
||||||
|
segment = segment.numpy()
|
||||||
|
self.state.segments.append(segment)
|
||||||
|
removed_len = 0
|
||||||
|
segments_len = self.segments_len()
|
||||||
|
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||||
|
removed_len = self.state.segments[0].shape[0] / 16000
|
||||||
|
segments_len -= removed_len
|
||||||
|
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
||||||
|
self.state.cumulative_time_offset += removed_len
|
||||||
|
self.state.segments = self.state.segments[1:]
|
||||||
|
logger.debug(
|
||||||
|
f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, "
|
||||||
|
f"cumulative offset: {self.state.cumulative_time_offset:.2f}s"
|
||||||
|
)
|
||||||
|
if len(self.state.tokens) > 1:
|
||||||
|
token_list = np.array(self.state.tokens[1][0, :]).tolist()
|
||||||
|
self.state.context.append_token_ids(token_list)
|
||||||
|
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
||||||
|
return removed_len
|
||||||
|
|
||||||
|
def _current_tokens(self) -> mx.array:
|
||||||
|
toks = self.state.tokens
|
||||||
|
if toks[0].shape[0] == 1:
|
||||||
|
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
|
||||||
|
if not self.state.context.is_empty():
|
||||||
|
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
|
||||||
|
toks = [context_toks] + toks
|
||||||
|
if len(toks) > 1:
|
||||||
|
current_tokens = mx.concatenate(toks, axis=1)
|
||||||
|
else:
|
||||||
|
current_tokens = toks[0]
|
||||||
|
logger.debug("debug print current_tokens:")
|
||||||
|
self.debug_print_tokens(current_tokens)
|
||||||
|
return current_tokens
|
||||||
|
|
||||||
|
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
|
||||||
|
if self.state.always_fire:
|
||||||
|
return True
|
||||||
|
if self.state.never_fire:
|
||||||
|
return False
|
||||||
|
return True # MLX CIF not implemented
|
||||||
|
|
||||||
|
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
|
||||||
|
n_audio = encoder_features.shape[0]
|
||||||
|
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
|
||||||
|
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
|
||||||
|
logits = logits[:, 0]
|
||||||
|
|
||||||
|
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
|
||||||
|
language_token_indices = mx.array(
|
||||||
|
list(self.tokenizer.all_language_tokens), dtype=mx.int32,
|
||||||
|
)
|
||||||
|
mask = mask.at[language_token_indices].add(False)
|
||||||
|
logits = mx.where(mask, mx.array(-float('inf')), logits)
|
||||||
|
|
||||||
|
language_tokens = mx.argmax(logits, axis=-1)
|
||||||
|
language_token_probs = mx.softmax(logits, axis=-1)
|
||||||
|
probs_np = np.array(language_token_probs)
|
||||||
|
language_probs = [
|
||||||
|
{
|
||||||
|
c: float(probs_np[i, j])
|
||||||
|
for j, c in zip(
|
||||||
|
self.tokenizer.all_language_tokens,
|
||||||
|
self.tokenizer.all_language_codes,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
for i in range(n_audio)
|
||||||
|
]
|
||||||
|
self._clean_cache()
|
||||||
|
return language_tokens, language_probs
|
||||||
|
|
||||||
|
def _concat_segments(self):
|
||||||
|
if len(self.state.segments) > 1:
|
||||||
|
return np.concatenate(self.state.segments, axis=0)
|
||||||
|
return self.state.segments[0]
|
||||||
|
|
||||||
|
def _encode(self, input_segments):
|
||||||
|
mlx_mel_padded = mlx_log_mel_spectrogram(
|
||||||
|
audio=input_segments,
|
||||||
|
n_mels=self.model.dims.n_mels,
|
||||||
|
padding=N_SAMPLES,
|
||||||
|
)
|
||||||
|
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||||
|
encoder_feature = self.model.encoder(mlx_mel[None])
|
||||||
|
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
|
||||||
|
return encoder_feature, content_mel_len
|
||||||
|
|
||||||
|
def _init_sum_logprobs(self):
|
||||||
|
return mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
|
||||||
|
|
||||||
|
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
|
||||||
|
if self.state.decoder_type == "greedy":
|
||||||
|
logits, self.state.kv_cache, cross_qk = self.model.decoder(
|
||||||
|
tokens, encoder_feature, kv_cache=self.state.kv_cache,
|
||||||
|
)
|
||||||
|
return logits, cross_qk
|
||||||
|
else:
|
||||||
|
return self.state.inference.logits(tokens, encoder_feature)
|
||||||
|
|
||||||
|
def _check_no_speech(self, logits):
|
||||||
|
if self.tokenizer.no_speech is not None:
|
||||||
|
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
|
||||||
|
no_speech_probs = np.array(
|
||||||
|
probs_at_sot[:, self.tokenizer.no_speech],
|
||||||
|
).tolist()
|
||||||
|
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||||
|
logger.info("no speech, stop")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _suppress_blank_tokens(self, logits):
|
||||||
|
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
|
||||||
|
logits = logits.at[:, blank_tokens].add(-float('inf'))
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def _apply_token_suppression(self, logits):
|
||||||
|
if self.state.suppress_tokens:
|
||||||
|
suppress_indices = mx.array(
|
||||||
|
list(self.state.suppress_tokens), dtype=mx.int32,
|
||||||
|
)
|
||||||
|
logits = logits.at[:, suppress_indices].add(-float('inf'))
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def _update_tokens(self, current_tokens, logits, sum_logprobs):
|
||||||
|
return self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||||
|
|
||||||
|
def _process_cross_attention(
|
||||||
|
self, cross_attns: List, content_mel_len: int,
|
||||||
|
) -> mx.array:
|
||||||
|
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
||||||
|
num_decoder_layers = self.num_decoder_layers
|
||||||
|
|
||||||
|
if cross_attns and isinstance(cross_attns[0], list):
|
||||||
|
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
|
||||||
|
else:
|
||||||
|
flattened_attns = cross_attns
|
||||||
|
|
||||||
|
for idx, attn_mat in enumerate(flattened_attns):
|
||||||
|
if attn_mat is None:
|
||||||
|
continue
|
||||||
|
layer_rank = idx % num_decoder_layers
|
||||||
|
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
||||||
|
if not align_heads_in_layer:
|
||||||
|
continue
|
||||||
|
attn_mat = mx.softmax(attn_mat, axis=-1)
|
||||||
|
for align_head_rank, head_id in align_heads_in_layer:
|
||||||
|
if self.cfg.beam_size == 1:
|
||||||
|
if attn_mat.ndim == 4:
|
||||||
|
a = attn_mat[0, head_id, :, :]
|
||||||
|
else:
|
||||||
|
a = attn_mat[head_id, :, :]
|
||||||
|
a = a[None, :, :]
|
||||||
|
else:
|
||||||
|
a = attn_mat[:, head_id, :, :]
|
||||||
|
attn_of_alignment_heads[align_head_rank].append(a)
|
||||||
|
|
||||||
|
tmp = []
|
||||||
|
for mat in attn_of_alignment_heads:
|
||||||
|
if mat:
|
||||||
|
tmp.append(mx.concatenate(mat, axis=1))
|
||||||
|
if not tmp:
|
||||||
|
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
|
||||||
|
|
||||||
|
attn_of_alignment_heads = mx.stack(tmp, axis=1)
|
||||||
|
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
|
||||||
|
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
|
||||||
|
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
||||||
|
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
|
||||||
|
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
|
||||||
|
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
||||||
|
mx.eval(attn_of_alignment_heads)
|
||||||
|
return attn_of_alignment_heads
|
||||||
|
|
||||||
|
def _get_attended_frames(self, attn):
|
||||||
|
most_attended_frames = mx.argmax(attn[:, -1, :], axis=-1)
|
||||||
|
frames_np = np.array(most_attended_frames)
|
||||||
|
return frames_np.tolist(), int(frames_np[0])
|
||||||
|
|
||||||
|
def _is_special_token(self, current_tokens):
|
||||||
|
return int(np.array(current_tokens[0, -2])) >= DEC_PAD
|
||||||
|
|
||||||
|
def _rewind_tokens(self):
|
||||||
|
if len(self.state.tokens) > 0:
|
||||||
|
return mx.concatenate(self.state.tokens, axis=1)
|
||||||
|
return self.state.tokens[0]
|
||||||
|
|
||||||
|
def _tokens_to_list(self, current_tokens, start_col):
|
||||||
|
return np.array(current_tokens[0, start_col:]).tolist()
|
||||||
|
|
||||||
|
def _make_new_tokens_tensor(self, hypothesis):
|
||||||
|
new_tokens = mx.array([hypothesis], dtype=mx.int32)
|
||||||
|
return mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
|
||||||
|
|
||||||
|
def _evaluate(self, tensor):
|
||||||
|
mx.eval(tensor)
|
||||||
@@ -7,21 +7,9 @@ from huggingface_hub import snapshot_download
|
|||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten
|
||||||
from mlx_whisper import whisper
|
from mlx_whisper import whisper
|
||||||
|
|
||||||
mlx_model_mapping = {
|
from whisperlivekit.model_mapping import MLX_MODEL_MAPPING
|
||||||
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
|
||||||
"tiny": "mlx-community/whisper-tiny-mlx",
|
mlx_model_mapping = MLX_MODEL_MAPPING
|
||||||
"base.en": "mlx-community/whisper-base.en-mlx",
|
|
||||||
"base": "mlx-community/whisper-base-mlx",
|
|
||||||
"small.en": "mlx-community/whisper-small.en-mlx",
|
|
||||||
"small": "mlx-community/whisper-small-mlx",
|
|
||||||
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
|
||||||
"medium": "mlx-community/whisper-medium-mlx",
|
|
||||||
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
|
||||||
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
|
||||||
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
|
||||||
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
|
||||||
"large": "mlx-community/whisper-large-mlx",
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_mlx_encoder(
|
def load_mlx_encoder(
|
||||||
path_or_hf_repo: str,
|
path_or_hf_repo: str,
|
||||||
@@ -69,3 +57,39 @@ def load_mlx_encoder(
|
|||||||
model.update(encoder_weights)
|
model.update(encoder_weights)
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_mlx_model(
|
||||||
|
path_or_hf_repo: str,
|
||||||
|
dtype: mx.Dtype = mx.float32,
|
||||||
|
) -> whisper.Whisper:
|
||||||
|
model_path = Path(path_or_hf_repo)
|
||||||
|
if not model_path.exists():
|
||||||
|
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
|
||||||
|
|
||||||
|
with open(str(model_path / "config.json"), "r") as f:
|
||||||
|
config = json.loads(f.read())
|
||||||
|
config.pop("model_type", None)
|
||||||
|
quantization = config.pop("quantization", None)
|
||||||
|
|
||||||
|
model_args = whisper.ModelDimensions(**config)
|
||||||
|
|
||||||
|
wf = model_path / "weights.safetensors"
|
||||||
|
if not wf.exists():
|
||||||
|
wf = model_path / "weights.npz"
|
||||||
|
weights = mx.load(str(wf))
|
||||||
|
|
||||||
|
model = whisper.Whisper(model_args, dtype)
|
||||||
|
|
||||||
|
if quantization is not None:
|
||||||
|
class_predicate = (
|
||||||
|
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||||
|
and f"{p}.scales" in weights
|
||||||
|
)
|
||||||
|
nn.quantize(model, **quantization, class_predicate=class_predicate)
|
||||||
|
|
||||||
|
weights = tree_unflatten(list(weights.items()))
|
||||||
|
|
||||||
|
model.update(weights)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
return model
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from time import time
|
from typing import List
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -9,8 +8,6 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from whisperlivekit.backend_support import (faster_backend_available,
|
from whisperlivekit.backend_support import (faster_backend_available,
|
||||||
mlx_backend_available)
|
mlx_backend_available)
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
|
||||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
|
||||||
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
|
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
|
||||||
TOKENS_PER_SECOND,
|
TOKENS_PER_SECOND,
|
||||||
log_mel_spectrogram, pad_or_trim)
|
log_mel_spectrogram, pad_or_trim)
|
||||||
@@ -18,14 +15,13 @@ from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
|
|||||||
SuppressTokens)
|
SuppressTokens)
|
||||||
from whisperlivekit.whisper.timing import median_filter
|
from whisperlivekit.whisper.timing import median_filter
|
||||||
|
|
||||||
from ..timed_objects import PUNCTUATION_MARKS
|
from .align_att_base import DEC_PAD, AlignAttBase
|
||||||
from .beam import BeamPyTorchInference
|
from .beam import BeamPyTorchInference
|
||||||
from .config import AlignAttConfig
|
from .config import AlignAttConfig
|
||||||
from .decoder_state import DecoderState
|
from .decoder_state import DecoderState
|
||||||
from .eow_detection import fire_at_boundary, load_cif
|
from .eow_detection import fire_at_boundary, load_cif
|
||||||
from .token_buffer import TokenBuffer
|
from .token_buffer import TokenBuffer
|
||||||
|
|
||||||
DEC_PAD = 50257
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if mlx_backend_available():
|
if mlx_backend_available():
|
||||||
@@ -46,7 +42,10 @@ def load_coreml_encoder():
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("coremltools is not installed")
|
logger.warning("coremltools is not installed")
|
||||||
return None
|
return None
|
||||||
COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage")
|
COREML_ENCODER_PATH = os.environ.get(
|
||||||
|
"MLCORE_ENCODER_PATH",
|
||||||
|
"whisperlivekit/whisper/whisper_encoder.mlpackage",
|
||||||
|
)
|
||||||
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
|
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
|
||||||
spec = _coreml_encoder.get_spec()
|
spec = _coreml_encoder.get_spec()
|
||||||
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
|
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
|
||||||
@@ -54,31 +53,14 @@ def load_coreml_encoder():
|
|||||||
return _coreml_encoder, _coreml_input_name, _coreml_output_name
|
return _coreml_encoder, _coreml_input_name, _coreml_output_name
|
||||||
|
|
||||||
|
|
||||||
class AlignAtt:
|
class AlignAtt(AlignAttBase):
|
||||||
"""
|
"""
|
||||||
Alignment-based Attention decoder for SimulStreaming.
|
PyTorch Alignment-based Attention decoder for SimulStreaming.
|
||||||
|
|
||||||
This class is now hookless - the model can be shared across multiple
|
Hookless — the model can be shared across multiple sessions,
|
||||||
sessions, with each session maintaining its own DecoderState.
|
with each session maintaining its own DecoderState.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Property accessors for backward compatibility
|
|
||||||
@property
|
|
||||||
def speaker(self):
|
|
||||||
return self.state.speaker
|
|
||||||
|
|
||||||
@speaker.setter
|
|
||||||
def speaker(self, value):
|
|
||||||
self.state.speaker = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def global_time_offset(self):
|
|
||||||
return self.state.global_time_offset
|
|
||||||
|
|
||||||
@global_time_offset.setter
|
|
||||||
def global_time_offset(self, value):
|
|
||||||
self.state.global_time_offset = value
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg: AlignAttConfig,
|
cfg: AlignAttConfig,
|
||||||
@@ -86,60 +68,35 @@ class AlignAtt:
|
|||||||
mlx_encoder=None,
|
mlx_encoder=None,
|
||||||
fw_encoder=None,
|
fw_encoder=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Shared model reference (can be shared across sessions)
|
|
||||||
self.model = loaded_model
|
|
||||||
self.mlx_encoder = mlx_encoder
|
self.mlx_encoder = mlx_encoder
|
||||||
self.fw_encoder = fw_encoder
|
self.fw_encoder = fw_encoder
|
||||||
if fw_encoder:
|
if fw_encoder:
|
||||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
self.fw_feature_extractor = FeatureExtractor(
|
||||||
|
feature_size=loaded_model.dims.n_mels,
|
||||||
|
)
|
||||||
self.coreml_encoder_tuple = None
|
self.coreml_encoder_tuple = None
|
||||||
if USE_MLCORE:
|
if USE_MLCORE:
|
||||||
self.coreml_encoder_tuple = load_coreml_encoder()
|
self.coreml_encoder_tuple = load_coreml_encoder()
|
||||||
self.use_mlcore = self.coreml_encoder_tuple is not None
|
self.use_mlcore = self.coreml_encoder_tuple is not None
|
||||||
|
|
||||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
# Common init (sets self.model, self.cfg, decode_options, etc.)
|
||||||
|
self._base_init(cfg, loaded_model)
|
||||||
logger.info(f"Model dimensions: {self.model.dims}")
|
logger.info(f"Model dimensions: {self.model.dims}")
|
||||||
self.decode_options = DecodingOptions(
|
|
||||||
language=cfg.language,
|
|
||||||
without_timestamps=True,
|
|
||||||
task=cfg.task
|
|
||||||
)
|
|
||||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
|
||||||
|
|
||||||
self.max_text_len = self.model.dims.n_text_ctx
|
# Per-session state
|
||||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
|
||||||
self.cfg = cfg
|
|
||||||
|
|
||||||
if self.cfg.max_context_tokens is None:
|
|
||||||
self.max_context_tokens = self.max_text_len
|
|
||||||
else:
|
|
||||||
self.max_context_tokens = self.cfg.max_context_tokens
|
|
||||||
|
|
||||||
# Initialize per-session state
|
|
||||||
self.state = DecoderState()
|
self.state = DecoderState()
|
||||||
self._init_state(cfg)
|
self._init_state(cfg)
|
||||||
|
|
||||||
def _init_state(self, cfg: AlignAttConfig):
|
def _init_state(self, cfg: AlignAttConfig):
|
||||||
"""Initialize the per-session decoder state."""
|
self._init_state_common(cfg)
|
||||||
# Create tokenizer
|
|
||||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
|
||||||
self.state.tokenizer = self.tokenizer
|
|
||||||
self.state.detected_language = cfg.language if cfg.language != "auto" else None
|
|
||||||
|
|
||||||
# Timing state
|
|
||||||
self.state.global_time_offset = 0.0
|
|
||||||
self.state.last_attend_frame = -cfg.rewind_threshold
|
|
||||||
self.state.speaker = -1
|
|
||||||
|
|
||||||
# CIF helpers for end-of-word boundary detection
|
# CIF helpers for end-of-word boundary detection
|
||||||
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
|
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
|
||||||
cfg,
|
cfg, n_audio_state=self.model.dims.n_audio_state, device=self.model.device,
|
||||||
n_audio_state=self.model.dims.n_audio_state,
|
|
||||||
device=self.model.device
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build alignment source mapping from model's alignment_heads
|
# Build alignment source mapping
|
||||||
self.state.align_source = {}
|
self.state.align_source = {}
|
||||||
self.state.num_align_heads = 0
|
self.state.num_align_heads = 0
|
||||||
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
||||||
@@ -151,12 +108,9 @@ class AlignAtt:
|
|||||||
|
|
||||||
# Build suppress tokens function
|
# Build suppress tokens function
|
||||||
suppress_tokens = [
|
suppress_tokens = [
|
||||||
self.tokenizer.transcribe,
|
self.tokenizer.transcribe, self.tokenizer.translate,
|
||||||
self.tokenizer.translate,
|
self.tokenizer.sot, self.tokenizer.sot_prev,
|
||||||
self.tokenizer.sot,
|
self.tokenizer.sot_lm, self.tokenizer.no_timestamps,
|
||||||
self.tokenizer.sot_prev,
|
|
||||||
self.tokenizer.sot_lm,
|
|
||||||
self.tokenizer.no_timestamps,
|
|
||||||
] + list(self.tokenizer.all_language_tokens)
|
] + list(self.tokenizer.all_language_tokens)
|
||||||
if self.tokenizer.no_speech is not None:
|
if self.tokenizer.no_speech is not None:
|
||||||
suppress_tokens.append(self.tokenizer.no_speech)
|
suppress_tokens.append(self.tokenizer.no_speech)
|
||||||
@@ -165,138 +119,80 @@ class AlignAtt:
|
|||||||
sup_tokens = SuppressTokens(suppress_tokens)
|
sup_tokens = SuppressTokens(suppress_tokens)
|
||||||
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
|
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
|
||||||
|
|
||||||
# Initialize tokens
|
|
||||||
self.init_tokens()
|
self.init_tokens()
|
||||||
self.init_context()
|
self.init_context()
|
||||||
|
|
||||||
# Set up decoder type
|
# Decoder type
|
||||||
self.state.decoder_type = cfg.decoder_type
|
self.state.decoder_type = cfg.decoder_type
|
||||||
if cfg.decoder_type == "greedy":
|
if cfg.decoder_type == "greedy":
|
||||||
logger.info("Using greedy decoder")
|
logger.info("Using greedy decoder")
|
||||||
self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
||||||
elif cfg.decoder_type == "beam":
|
elif cfg.decoder_type == "beam":
|
||||||
logger.info("Using beam decoder")
|
logger.info("Using beam decoder")
|
||||||
self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length)
|
self.state.inference = BeamPyTorchInference(
|
||||||
|
self.model, self.state.initial_token_length,
|
||||||
|
)
|
||||||
self.state.inference.kv_cache = self.state.kv_cache
|
self.state.inference.kv_cache = self.state.kv_cache
|
||||||
self.state.token_decoder = BeamSearchDecoder(
|
self.state.token_decoder = BeamSearchDecoder(
|
||||||
inference=self.state.inference,
|
inference=self.state.inference,
|
||||||
eot=self.tokenizer.eot,
|
eot=self.tokenizer.eot,
|
||||||
beam_size=cfg.beam_size
|
beam_size=cfg.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def warmup(self, audio):
|
# === Abstract method implementations ===
|
||||||
try:
|
|
||||||
self.insert_audio(audio)
|
|
||||||
self.infer(is_last=True)
|
|
||||||
self.refresh_segment(complete=True)
|
|
||||||
logger.info("Model warmed up successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Model warmup failed: {e}")
|
|
||||||
|
|
||||||
def create_tokenizer(self, language=None):
|
def init_tokens(self):
|
||||||
self.tokenizer = tokenizer.get_tokenizer(
|
logger.debug(f"init tokens, {len(self.state.segments)}")
|
||||||
multilingual=self.tokenizer_is_multilingual,
|
self.state.initial_tokens = torch.tensor(
|
||||||
language=language,
|
self.tokenizer.sot_sequence_including_notimestamps,
|
||||||
num_languages=self.model.num_languages,
|
dtype=torch.long, device=self.model.device,
|
||||||
task=self.decode_options.task
|
).unsqueeze(0)
|
||||||
)
|
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
||||||
self.state.tokenizer = self.tokenizer
|
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||||
|
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
||||||
|
self.state.tokens = [self.state.initial_tokens]
|
||||||
|
|
||||||
def init_context(self):
|
def init_context(self):
|
||||||
kw = {'tokenizer': self.tokenizer,
|
kw = {
|
||||||
|
'tokenizer': self.tokenizer,
|
||||||
'device': self.model.device,
|
'device': self.model.device,
|
||||||
'prefix_token_ids': [self.tokenizer.sot_prev]}
|
'prefix_token_ids': [self.tokenizer.sot_prev],
|
||||||
|
}
|
||||||
self.state.context = TokenBuffer.empty(**kw)
|
self.state.context = TokenBuffer.empty(**kw)
|
||||||
if self.cfg.static_init_prompt is not None:
|
if self.cfg.static_init_prompt is not None:
|
||||||
self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||||
if self.cfg.init_prompt is not None:
|
if self.cfg.init_prompt is not None:
|
||||||
self.state.context.text += self.cfg.init_prompt
|
self.state.context.text += self.cfg.init_prompt
|
||||||
|
|
||||||
def init_tokens(self):
|
def insert_audio(self, segment=None):
|
||||||
logger.debug(f"init tokens, {len(self.state.segments)}")
|
if segment is not None:
|
||||||
# init tokens (mandatory prompt)
|
self.state.segments.append(segment)
|
||||||
self.state.initial_tokens = torch.tensor(
|
removed_len = 0
|
||||||
self.tokenizer.sot_sequence_including_notimestamps,
|
segments_len = self.segments_len()
|
||||||
dtype=torch.long,
|
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||||
device=self.model.device).unsqueeze(0)
|
removed_len = self.state.segments[0].shape[0] / 16000
|
||||||
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
segments_len -= removed_len
|
||||||
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
||||||
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
self.state.cumulative_time_offset += removed_len
|
||||||
self.state.tokens = [self.state.initial_tokens]
|
self.state.segments = self.state.segments[1:]
|
||||||
|
logger.debug(
|
||||||
def trim_context(self):
|
f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, "
|
||||||
logger.info("Trimming context")
|
f"cumulative offset: {self.state.cumulative_time_offset:.2f}s"
|
||||||
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
|
|
||||||
logger.info(f"Context text: {self.state.context.as_text()}")
|
|
||||||
l = sum(t.shape[1] for t in self.state.tokens) + c
|
|
||||||
if self.cfg.static_init_prompt is None:
|
|
||||||
after = 0
|
|
||||||
else:
|
|
||||||
after = len(self.cfg.static_init_prompt)
|
|
||||||
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
|
||||||
t = self.state.context.trim_words(after=after)
|
|
||||||
l -= t
|
|
||||||
c -= t
|
|
||||||
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
|
||||||
if t == 0:
|
|
||||||
break
|
|
||||||
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
|
|
||||||
|
|
||||||
|
|
||||||
def logits(
|
|
||||||
self,
|
|
||||||
tokens: torch.Tensor,
|
|
||||||
audio_features: torch.Tensor,
|
|
||||||
return_cross_attn: bool = False
|
|
||||||
):
|
|
||||||
"""Get logits from decoder, optionally returning cross-attention weights."""
|
|
||||||
if self.state.decoder_type == "greedy":
|
|
||||||
return self.model.decoder(
|
|
||||||
tokens, audio_features,
|
|
||||||
kv_cache=self.state.kv_cache,
|
|
||||||
return_cross_attn=return_cross_attn
|
|
||||||
)
|
)
|
||||||
else:
|
if len(self.state.tokens) > 1:
|
||||||
logger.debug(f"Logits shape: {tokens.shape}")
|
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
|
||||||
return self.state.inference.logits(
|
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
||||||
tokens, audio_features,
|
return removed_len
|
||||||
return_cross_attn=return_cross_attn
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_segment(self, complete=False):
|
|
||||||
logger.debug("Refreshing segment:")
|
|
||||||
self.init_tokens()
|
|
||||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
|
||||||
self.state.cumulative_time_offset = 0.0
|
|
||||||
self.init_context()
|
|
||||||
logger.debug(f"Context: {self.state.context}")
|
|
||||||
if not complete and len(self.state.segments) > 2:
|
|
||||||
self.state.segments = self.state.segments[-2:]
|
|
||||||
else:
|
|
||||||
logger.debug("removing all segments.")
|
|
||||||
self.state.segments = []
|
|
||||||
self.state.log_segments += 1
|
|
||||||
self.state.pending_incomplete_tokens = []
|
|
||||||
|
|
||||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
|
||||||
if self.state.always_fire:
|
|
||||||
return True
|
|
||||||
if self.state.never_fire:
|
|
||||||
return False
|
|
||||||
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
|
|
||||||
|
|
||||||
def _current_tokens(self):
|
def _current_tokens(self):
|
||||||
toks = self.state.tokens
|
toks = self.state.tokens
|
||||||
# very first infer: duplicate start of seq to beam_size
|
|
||||||
if toks[0].shape[0] == 1:
|
if toks[0].shape[0] == 1:
|
||||||
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
|
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
|
||||||
|
|
||||||
if not self.state.context.is_empty():
|
if not self.state.context.is_empty():
|
||||||
context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
|
context_toks = self.state.context.as_tensor_beam(
|
||||||
|
self.cfg.beam_size, device=self.model.device,
|
||||||
|
)
|
||||||
toks = [context_toks] + toks
|
toks = [context_toks] + toks
|
||||||
|
|
||||||
# make it one tensor
|
|
||||||
if len(toks) > 1:
|
if len(toks) > 1:
|
||||||
current_tokens = torch.cat(toks, dim=1)
|
current_tokens = torch.cat(toks, dim=1)
|
||||||
else:
|
else:
|
||||||
@@ -305,60 +201,19 @@ class AlignAtt:
|
|||||||
self.debug_print_tokens(current_tokens)
|
self.debug_print_tokens(current_tokens)
|
||||||
return current_tokens
|
return current_tokens
|
||||||
|
|
||||||
|
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||||
def debug_print_tokens(self, tokens):
|
if self.state.always_fire:
|
||||||
for i in range(self.cfg.beam_size):
|
|
||||||
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
|
|
||||||
|
|
||||||
### audio buffer
|
|
||||||
|
|
||||||
def segments_len(self):
|
|
||||||
segments_len = sum(s.shape[0] for s in self.state.segments) / 16000
|
|
||||||
return segments_len
|
|
||||||
|
|
||||||
def _apply_minseglen(self):
|
|
||||||
segments_len = self.segments_len()
|
|
||||||
# wait for long enough audio to start
|
|
||||||
if segments_len < self.cfg.audio_min_len:
|
|
||||||
logger.debug("waiting for next segment")
|
|
||||||
return False
|
|
||||||
return True
|
return True
|
||||||
|
if self.state.never_fire:
|
||||||
def insert_audio(self, segment=None):
|
return False
|
||||||
if segment is not None:
|
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
|
||||||
self.state.segments.append(segment)
|
|
||||||
|
|
||||||
removed_len = 0
|
|
||||||
# len of audio is bigger than buffer_len. Going to remove the first segment
|
|
||||||
segments_len = self.segments_len()
|
|
||||||
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
|
||||||
removed_len = self.state.segments[0].shape[0] / 16000
|
|
||||||
segments_len -= removed_len
|
|
||||||
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
|
||||||
self.state.cumulative_time_offset += removed_len # Track cumulative time removed
|
|
||||||
self.state.segments = self.state.segments[1:]
|
|
||||||
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
|
|
||||||
if len(self.state.tokens) > 1:
|
|
||||||
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
|
|
||||||
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
|
||||||
return removed_len
|
|
||||||
|
|
||||||
def _clean_cache(self):
|
|
||||||
"""Clean the kv_cache after each inference step."""
|
|
||||||
self.state.clean_cache()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def lang_id(self, encoder_features):
|
def lang_id(self, encoder_features):
|
||||||
"""Language detection from encoder features.
|
|
||||||
This code is trimmed and copy-pasted from whisper.decoding.detect_language.
|
|
||||||
"""
|
|
||||||
# forward pass using a single token, startoftranscript
|
|
||||||
n_audio = encoder_features.shape[0]
|
n_audio = encoder_features.shape[0]
|
||||||
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device)
|
||||||
# Note: don't use kv_cache for language detection
|
|
||||||
logits = self.model.logits(x, encoder_features)[:, 0]
|
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||||
|
|
||||||
# collect detected languages; suppress all non-language tokens
|
|
||||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||||
mask[list(self.tokenizer.all_language_tokens)] = False
|
mask[list(self.tokenizer.all_language_tokens)] = False
|
||||||
logits[:, mask] = -np.inf
|
logits[:, mask] = -np.inf
|
||||||
@@ -367,46 +222,31 @@ class AlignAtt:
|
|||||||
language_probs = [
|
language_probs = [
|
||||||
{
|
{
|
||||||
c: language_token_probs[i, j].item()
|
c: language_token_probs[i, j].item()
|
||||||
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
|
for j, c in zip(
|
||||||
|
self.tokenizer.all_language_tokens,
|
||||||
|
self.tokenizer.all_language_codes,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
for i in range(n_audio)
|
for i in range(n_audio)
|
||||||
]
|
]
|
||||||
|
|
||||||
single = encoder_features.ndim == 2
|
single = encoder_features.ndim == 2
|
||||||
if single:
|
if single:
|
||||||
language_tokens = language_tokens[0]
|
language_tokens = language_tokens[0]
|
||||||
language_probs = language_probs[0]
|
language_probs = language_probs[0]
|
||||||
|
|
||||||
self._clean_cache()
|
self._clean_cache()
|
||||||
return language_tokens, language_probs
|
return language_tokens, language_probs
|
||||||
|
|
||||||
### transcription / translation
|
def _concat_segments(self):
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def infer(self, is_last=False):
|
|
||||||
new_segment = True
|
|
||||||
if len(self.state.segments) == 0:
|
|
||||||
logger.debug("No segments, nothing to do")
|
|
||||||
return []
|
|
||||||
if not self._apply_minseglen():
|
|
||||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
|
||||||
input_segments = torch.cat(self.state.segments, dim=0)
|
|
||||||
return []
|
|
||||||
|
|
||||||
# input_segments is concatenation of audio, it's one array
|
|
||||||
if len(self.state.segments) > 1:
|
if len(self.state.segments) > 1:
|
||||||
input_segments = torch.cat(self.state.segments, dim=0)
|
return torch.cat(self.state.segments, dim=0)
|
||||||
else:
|
return self.state.segments[0]
|
||||||
input_segments = self.state.segments[0]
|
|
||||||
|
|
||||||
beg_encode = time()
|
def _encode(self, input_segments):
|
||||||
if self.use_mlcore:
|
if self.use_mlcore:
|
||||||
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
|
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
|
||||||
mel_padded = log_mel_spectrogram(
|
mel_padded = log_mel_spectrogram(
|
||||||
input_segments,
|
input_segments, n_mels=self.model.dims.n_mels,
|
||||||
n_mels=self.model.dims.n_mels,
|
padding=N_SAMPLES, device="cpu",
|
||||||
padding=N_SAMPLES,
|
|
||||||
device="cpu",
|
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
||||||
@@ -418,288 +258,151 @@ class AlignAtt:
|
|||||||
else:
|
else:
|
||||||
encoder_feature_np = next(iter(coreml_outputs.values()))
|
encoder_feature_np = next(iter(coreml_outputs.values()))
|
||||||
encoder_feature = torch.as_tensor(
|
encoder_feature = torch.as_tensor(
|
||||||
np.array(encoder_feature_np),
|
np.array(encoder_feature_np), device=self.device,
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
if self.mlx_encoder:
|
if self.mlx_encoder:
|
||||||
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
|
mlx_mel_padded = mlx_log_mel_spectrogram(
|
||||||
|
audio=input_segments.detach(),
|
||||||
|
n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||||
|
)
|
||||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||||
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
||||||
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
||||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
|
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
|
||||||
elif self.fw_encoder:
|
elif self.fw_encoder:
|
||||||
audio_length_seconds = len(input_segments) / 16000
|
audio_length_seconds = len(input_segments) / 16000
|
||||||
content_mel_len = int(audio_length_seconds * 100)//2
|
content_mel_len = int(audio_length_seconds * 100) // 2
|
||||||
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
|
mel_padded_2 = self.fw_feature_extractor(
|
||||||
|
waveform=input_segments.numpy(), padding=N_SAMPLES,
|
||||||
|
)[None, :]
|
||||||
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
||||||
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
||||||
if self.device == 'cpu': #it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works
|
if self.device == 'cpu':
|
||||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||||
try:
|
try:
|
||||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||||
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case:
|
except TypeError:
|
||||||
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
|
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
|
||||||
|
arr = np.array(encoder_feature_ctranslate)
|
||||||
|
if arr.dtype == np.object_:
|
||||||
|
arr = np.array(arr.tolist(), dtype=np.float32)
|
||||||
|
encoder_feature = torch.as_tensor(arr, device=self.device)
|
||||||
else:
|
else:
|
||||||
# mel + padding to 30s
|
mel_padded = log_mel_spectrogram(
|
||||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
input_segments, n_mels=self.model.dims.n_mels,
|
||||||
device=self.device).unsqueeze(0)
|
padding=N_SAMPLES, device=self.device,
|
||||||
# trim to 3000
|
).unsqueeze(0)
|
||||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||||
# the len of actual audio
|
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
||||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
|
||||||
encoder_feature = self.model.encoder(mel)
|
encoder_feature = self.model.encoder(mel)
|
||||||
end_encode = time()
|
return encoder_feature, content_mel_len
|
||||||
# print('Encoder duration:', end_encode-beg_encode)
|
|
||||||
|
|
||||||
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
|
def _init_sum_logprobs(self):
|
||||||
seconds_since_start = self.segments_len() - self.state.first_timestamp
|
return torch.zeros(self.cfg.beam_size, device=self.device)
|
||||||
if seconds_since_start >= 2.0:
|
|
||||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
|
||||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
|
||||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
|
||||||
self.create_tokenizer(top_lan)
|
|
||||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
|
||||||
self.state.cumulative_time_offset = 0.0
|
|
||||||
self.init_tokens()
|
|
||||||
self.init_context()
|
|
||||||
self.state.detected_language = top_lan
|
|
||||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
|
||||||
|
|
||||||
self.trim_context()
|
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
|
||||||
current_tokens = self._current_tokens()
|
if self.state.decoder_type == "greedy":
|
||||||
|
return self.model.decoder(
|
||||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
tokens, encoder_feature,
|
||||||
|
kv_cache=self.state.kv_cache,
|
||||||
|
return_cross_attn=True,
|
||||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
)
|
||||||
completed = False
|
|
||||||
# punctuation_stop = False
|
|
||||||
|
|
||||||
attn_of_alignment_heads = None
|
|
||||||
most_attended_frame = None
|
|
||||||
|
|
||||||
token_len_before_decoding = current_tokens.shape[1]
|
|
||||||
|
|
||||||
l_absolute_timestamps = []
|
|
||||||
|
|
||||||
accumulated_cross_attns = []
|
|
||||||
|
|
||||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
|
||||||
|
|
||||||
if new_segment:
|
|
||||||
tokens_for_logits = current_tokens
|
|
||||||
else:
|
else:
|
||||||
# only need to use the last token except in the first forward pass
|
logger.debug(f"Logits shape: {tokens.shape}")
|
||||||
tokens_for_logits = current_tokens[:, -1:]
|
return self.state.inference.logits(
|
||||||
|
tokens, encoder_feature, return_cross_attn=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Get logits and cross-attention weights from decoder
|
def _check_no_speech(self, logits):
|
||||||
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
if self.tokenizer.no_speech is not None:
|
||||||
logits, cross_attns = result
|
|
||||||
|
|
||||||
# Accumulate cross-attention from this forward pass
|
|
||||||
accumulated_cross_attns.append(cross_attns)
|
|
||||||
|
|
||||||
if new_segment and self.tokenizer.no_speech is not None:
|
|
||||||
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||||
logger.info("no speech, stop")
|
logger.info("no speech, stop")
|
||||||
break
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
logits = logits[:, -1, :] # logits for the last token
|
def _suppress_blank_tokens(self, logits):
|
||||||
|
|
||||||
# suppress blank tokens only at the beginning of the segment
|
|
||||||
if new_segment:
|
|
||||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||||
new_segment = False
|
return logits
|
||||||
|
|
||||||
|
def _apply_token_suppression(self, logits):
|
||||||
self.state.suppress_tokens_fn(logits)
|
self.state.suppress_tokens_fn(logits)
|
||||||
current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
|
return logits
|
||||||
|
|
||||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
def _update_tokens(self, current_tokens, logits, sum_logprobs):
|
||||||
self.debug_print_tokens(current_tokens)
|
return self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||||
|
|
||||||
# Process accumulated cross-attention weights for alignment
|
|
||||||
attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
|
|
||||||
|
|
||||||
# for each beam, the most attended frame is:
|
|
||||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1)
|
|
||||||
|
|
||||||
# Calculate absolute timestamps accounting for cumulative offset
|
|
||||||
absolute_timestamps = [
|
|
||||||
(frame * 0.02 + self.state.cumulative_time_offset)
|
|
||||||
for frame in most_attended_frames.tolist()
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
|
||||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)")
|
|
||||||
|
|
||||||
most_attended_frame = most_attended_frames[0].item()
|
|
||||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
|
||||||
|
|
||||||
logger.debug("current tokens" + str(current_tokens.shape))
|
|
||||||
if completed:
|
|
||||||
# stripping the last token, the eot
|
|
||||||
current_tokens = current_tokens[:, :-1]
|
|
||||||
break
|
|
||||||
|
|
||||||
# for some rare cases where the attention fails
|
|
||||||
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
|
||||||
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
|
|
||||||
logger.debug("omit rewinding from special tokens")
|
|
||||||
self.state.last_attend_frame = most_attended_frame
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"[rewind detected] current attention pos: {most_attended_frame}, "
|
|
||||||
f"last attention pos: {self.state.last_attend_frame}; omit this segment")
|
|
||||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
|
||||||
current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
self.state.last_attend_frame = most_attended_frame
|
|
||||||
|
|
||||||
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
|
||||||
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
|
||||||
# stripping the last token, the one that is attended too close to the end
|
|
||||||
current_tokens = current_tokens[:, :-1]
|
|
||||||
break
|
|
||||||
|
|
||||||
# debug print
|
|
||||||
for i in range(self.cfg.beam_size):
|
|
||||||
logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
|
|
||||||
attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
|
|
||||||
most_attended_frames[i],
|
|
||||||
current_tokens[i, -1].item(),
|
|
||||||
self.tokenizer.decode([current_tokens[i, -1].item()])
|
|
||||||
))
|
|
||||||
|
|
||||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
|
||||||
|
|
||||||
# Prepend pending tokens from previous chunk if any
|
|
||||||
if self.state.pending_incomplete_tokens:
|
|
||||||
logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}")
|
|
||||||
pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
|
||||||
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
|
||||||
|
|
||||||
if fire_detected or is_last:
|
|
||||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
|
||||||
else:
|
|
||||||
# going to truncate the tokens after the last space
|
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
|
|
||||||
if len(split_words) > 1:
|
|
||||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
|
||||||
else:
|
|
||||||
new_hypothesis = []
|
|
||||||
|
|
||||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
|
||||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
self.state.tokens.append(new_tokens)
|
|
||||||
|
|
||||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
|
||||||
|
|
||||||
self._clean_cache()
|
|
||||||
|
|
||||||
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
|
|
||||||
self.state.first_timestamp = l_absolute_timestamps[0]
|
|
||||||
|
|
||||||
timestamped_words = []
|
|
||||||
timestamp_idx = 0
|
|
||||||
replacement_char = "\ufffd"
|
|
||||||
for word, word_tokens in zip(split_words, split_tokens):
|
|
||||||
# Skip words containing incomplete UTF-8 from client output
|
|
||||||
if replacement_char in word:
|
|
||||||
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
|
|
||||||
timestamp_idx += len(word_tokens)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
timestamp_idx += len(word_tokens)
|
|
||||||
|
|
||||||
timestamp_entry = ASRToken(
|
|
||||||
start=round(current_timestamp, 2),
|
|
||||||
end=round(current_timestamp + 0.1, 2),
|
|
||||||
text=word,
|
|
||||||
speaker=self.state.speaker,
|
|
||||||
detected_language=self.state.detected_language
|
|
||||||
).with_offset(
|
|
||||||
self.state.global_time_offset
|
|
||||||
)
|
|
||||||
timestamped_words.append(timestamp_entry)
|
|
||||||
|
|
||||||
# Hold incomplete tokens for next chunk
|
|
||||||
self.state.pending_incomplete_tokens = []
|
|
||||||
if split_words and replacement_char in split_words[-1]:
|
|
||||||
self.state.pending_incomplete_tokens = split_tokens[-1]
|
|
||||||
logger.warning(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.state.pending_incomplete_tokens}")
|
|
||||||
|
|
||||||
return timestamped_words
|
|
||||||
|
|
||||||
def _process_cross_attention(
|
def _process_cross_attention(
|
||||||
self,
|
self, cross_attns: List, content_mel_len: int,
|
||||||
cross_attns: List[torch.Tensor],
|
|
||||||
content_mel_len: int
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
|
||||||
Process cross-attention weights from decoder layers for alignment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cross_attns: List of cross-attention tensors from each decoder layer.
|
|
||||||
Each tensor has shape (batch, n_head, seq_len, audio_len)
|
|
||||||
content_mel_len: Length of actual audio content in mel frames
|
|
||||||
|
|
||||||
Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len)
|
|
||||||
"""
|
|
||||||
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
||||||
num_decoder_layers = len(self.model.decoder.blocks)
|
num_decoder_layers = len(self.model.decoder.blocks)
|
||||||
|
|
||||||
if cross_attns and isinstance(cross_attns[0], list):
|
if cross_attns and isinstance(cross_attns[0], list):
|
||||||
flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list]
|
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
|
||||||
else:
|
else:
|
||||||
flattened_attns = cross_attns
|
flattened_attns = cross_attns
|
||||||
|
|
||||||
for idx, attn_mat in enumerate(flattened_attns):
|
for idx, attn_mat in enumerate(flattened_attns):
|
||||||
layer_rank = idx % num_decoder_layers
|
layer_rank = idx % num_decoder_layers
|
||||||
# attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1
|
|
||||||
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
||||||
if len(align_heads_in_layer) == 0:
|
if not align_heads_in_layer:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
attn_mat = F.softmax(attn_mat, dim=-1)
|
attn_mat = F.softmax(attn_mat, dim=-1)
|
||||||
|
|
||||||
for align_head_rank, head_id in align_heads_in_layer:
|
for align_head_rank, head_id in align_heads_in_layer:
|
||||||
if self.cfg.beam_size == 1:
|
if self.cfg.beam_size == 1:
|
||||||
# (n_head, seq_len, audio_len) when squeezed
|
|
||||||
if attn_mat.dim() == 4:
|
if attn_mat.dim() == 4:
|
||||||
a = attn_mat[0, head_id, :, :] # (seq_len, audio_len)
|
a = attn_mat[0, head_id, :, :]
|
||||||
else:
|
else:
|
||||||
a = attn_mat[head_id, :, :]
|
a = attn_mat[head_id, :, :]
|
||||||
a = a.unsqueeze(0) # (1, seq_len, audio_len)
|
a = a.unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
# attn_mat: (batch, n_head, seq_len, audio_len)
|
a = attn_mat[:, head_id, :, :]
|
||||||
a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len)
|
|
||||||
attn_of_alignment_heads[align_head_rank].append(a)
|
attn_of_alignment_heads[align_head_rank].append(a)
|
||||||
|
|
||||||
tmp = []
|
tmp = []
|
||||||
for mat in attn_of_alignment_heads:
|
for mat in attn_of_alignment_heads:
|
||||||
if mat:
|
if mat:
|
||||||
t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len)
|
tmp.append(torch.cat(mat, dim=1))
|
||||||
tmp.append(t)
|
|
||||||
|
|
||||||
if not tmp:
|
if not tmp:
|
||||||
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
|
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
|
||||||
|
|
||||||
# stck al heads: (batch, num_align_heads, seq_len, audio_len)
|
|
||||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||||
|
std, mean = torch.std_mean(
|
||||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False,
|
||||||
|
)
|
||||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
||||||
|
|
||||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
|
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
|
||||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||||
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
||||||
return attn_of_alignment_heads
|
return attn_of_alignment_heads
|
||||||
|
|
||||||
|
def _get_attended_frames(self, attn):
|
||||||
|
most_attended_frames = torch.argmax(attn[:, -1, :], dim=-1)
|
||||||
|
return most_attended_frames.tolist(), most_attended_frames[0].item()
|
||||||
|
|
||||||
|
def _is_special_token(self, current_tokens):
|
||||||
|
return current_tokens[0, -2].item() >= DEC_PAD
|
||||||
|
|
||||||
|
def _rewind_tokens(self):
|
||||||
|
if len(self.state.tokens) > 0:
|
||||||
|
return torch.cat(self.state.tokens, dim=1)
|
||||||
|
return self.state.tokens[0]
|
||||||
|
|
||||||
|
def _tokens_to_list(self, current_tokens, start_col):
|
||||||
|
return current_tokens[0, start_col:].flatten().tolist()
|
||||||
|
|
||||||
|
def _make_new_tokens_tensor(self, hypothesis):
|
||||||
|
return (
|
||||||
|
torch.tensor([hypothesis], dtype=torch.long)
|
||||||
|
.repeat_interleave(self.cfg.beam_size, dim=0)
|
||||||
|
.to(device=self.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _evaluate(self, tensor):
|
||||||
|
pass # No-op for PyTorch
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def infer(self, is_last=False):
|
||||||
|
return super().infer(is_last)
|
||||||
|
|||||||
139
whisperlivekit/thread_safety.py
Normal file
139
whisperlivekit/thread_safety.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Thread Safety Configuration for WhisperLiveKit
|
||||||
|
|
||||||
|
This module provides thread safety configuration and utilities.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
|
||||||
|
Set to "0" to disable for single-connection deployments
|
||||||
|
|
||||||
|
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Enable model locking (default)
|
||||||
|
export WHISPERLIVEKIT_MODEL_LOCK=1
|
||||||
|
|
||||||
|
# Disable for single-connection deployment
|
||||||
|
export WHISPERLIVEKIT_MODEL_LOCK=0
|
||||||
|
|
||||||
|
# Custom timeout
|
||||||
|
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
|
||||||
|
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
|
||||||
|
|
||||||
|
# Global model lock
|
||||||
|
_model_lock = threading.Lock()
|
||||||
|
|
||||||
|
# Log configuration on import
|
||||||
|
if USE_MODEL_LOCK:
|
||||||
|
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
|
||||||
|
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
|
||||||
|
else:
|
||||||
|
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_lock():
|
||||||
|
"""Get the global model lock instance"""
|
||||||
|
return _model_lock
|
||||||
|
|
||||||
|
|
||||||
|
def acquire_model_lock(timeout=None):
|
||||||
|
"""
|
||||||
|
Acquire model lock with timeout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if lock acquired, False on timeout
|
||||||
|
"""
|
||||||
|
if not USE_MODEL_LOCK:
|
||||||
|
return True
|
||||||
|
|
||||||
|
timeout = timeout or LOCK_TIMEOUT
|
||||||
|
acquired = _model_lock.acquire(timeout=timeout)
|
||||||
|
|
||||||
|
if not acquired:
|
||||||
|
logger.error(f"Failed to acquire model lock within {timeout}s")
|
||||||
|
|
||||||
|
return acquired
|
||||||
|
|
||||||
|
|
||||||
|
def release_model_lock():
|
||||||
|
"""Release model lock"""
|
||||||
|
if not USE_MODEL_LOCK:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
_model_lock.release()
|
||||||
|
except RuntimeError:
|
||||||
|
# Lock not held - this is fine
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLockContext:
|
||||||
|
"""Context manager for model lock"""
|
||||||
|
|
||||||
|
def __init__(self, timeout=None):
|
||||||
|
self.timeout = timeout
|
||||||
|
self.acquired = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.acquired = acquire_model_lock(self.timeout)
|
||||||
|
return self.acquired
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if self.acquired:
|
||||||
|
release_model_lock()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Concurrency recommendations
|
||||||
|
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
|
||||||
|
RECOMMENDED_WORKERS = 4
|
||||||
|
|
||||||
|
def print_deployment_recommendations():
|
||||||
|
"""Print recommended deployment configuration"""
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("WhisperLiveKit Deployment Recommendations")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
if USE_MODEL_LOCK:
|
||||||
|
print("⚠️ Model locking is ENABLED")
|
||||||
|
print(" This serializes inference across connections.")
|
||||||
|
print()
|
||||||
|
print("Recommended deployment:")
|
||||||
|
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
|
||||||
|
print(" -k uvicorn.workers.UvicornWorker \\")
|
||||||
|
print(" --worker-connections 1 \\")
|
||||||
|
print(" whisperlivekit.basic_server:app")
|
||||||
|
print()
|
||||||
|
print("Expected capacity:")
|
||||||
|
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
|
||||||
|
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
|
||||||
|
else:
|
||||||
|
print("✅ Model locking is DISABLED")
|
||||||
|
print(" ⚠️ ONLY safe for single-connection deployments")
|
||||||
|
print()
|
||||||
|
print("Recommended deployment:")
|
||||||
|
print(" uvicorn whisperlivekit.basic_server:app \\")
|
||||||
|
print(" --host 0.0.0.0 --port 8000 \\")
|
||||||
|
print(" --workers 1")
|
||||||
|
print()
|
||||||
|
print("Expected capacity:")
|
||||||
|
print(" - 1 concurrent user only")
|
||||||
|
|
||||||
|
print("="*60 + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print_deployment_recommendations()
|
||||||
@@ -39,10 +39,11 @@ class TimedText(Timed):
|
|||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class ASRToken(TimedText):
|
class ASRToken(TimedText):
|
||||||
|
probability: Optional[float] = None
|
||||||
|
|
||||||
def with_offset(self, offset: float) -> "ASRToken":
|
def with_offset(self, offset: float) -> "ASRToken":
|
||||||
"""Return a new token with the time offset added."""
|
"""Return a new token with the time offset added."""
|
||||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
|
||||||
|
|
||||||
def is_silence(self) -> bool:
|
def is_silence(self) -> bool:
|
||||||
return False
|
return False
|
||||||
@@ -114,6 +115,9 @@ class Segment(TimedText):
|
|||||||
end: Optional[float]
|
end: Optional[float]
|
||||||
text: Optional[str]
|
text: Optional[str]
|
||||||
speaker: Optional[str]
|
speaker: Optional[str]
|
||||||
|
tokens: Optional[ASRToken] = None
|
||||||
|
translation: Optional[Translation] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tokens(
|
def from_tokens(
|
||||||
cls,
|
cls,
|
||||||
@@ -141,17 +145,13 @@ class Segment(TimedText):
|
|||||||
speaker=-1,
|
speaker=-1,
|
||||||
detected_language=start_token.detected_language
|
detected_language=start_token.detected_language
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_silence(self) -> bool:
|
def is_silence(self) -> bool:
|
||||||
"""True when this segment represents a silence gap."""
|
"""True when this segment represents a silence gap."""
|
||||||
return self.speaker == -2
|
return self.speaker == -2
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Line(TimedText):
|
|
||||||
translation: str = ''
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Serialize the line for frontend consumption."""
|
"""Serialize the segment for frontend consumption."""
|
||||||
_dict: Dict[str, Any] = {
|
_dict: Dict[str, Any] = {
|
||||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||||
'text': self.text,
|
'text': self.text,
|
||||||
@@ -164,28 +164,12 @@ class Line(TimedText):
|
|||||||
_dict['detected_language'] = self.detected_language
|
_dict['detected_language'] = self.detected_language
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
|
|
||||||
"""Populate line attributes from a contiguous token list."""
|
|
||||||
self.text = ''.join([token.text for token in tokens])
|
|
||||||
self.start = tokens[0].start
|
|
||||||
self.end = tokens[-1].end
|
|
||||||
self.speaker = 1
|
|
||||||
self.detected_language = tokens[0].detected_language
|
|
||||||
return self
|
|
||||||
|
|
||||||
def build_from_segment(self, segment: Segment) -> "Line":
|
@dataclass
|
||||||
"""Populate the line fields from a pre-built segment."""
|
class PuncSegment(Segment):
|
||||||
self.text = segment.text
|
pass
|
||||||
self.start = segment.start
|
|
||||||
self.end = segment.end
|
|
||||||
self.speaker = segment.speaker
|
|
||||||
self.detected_language = segment.detected_language
|
|
||||||
return self
|
|
||||||
|
|
||||||
def is_silent(self) -> bool:
|
class SilentSegment(Segment):
|
||||||
return self.speaker == -2
|
|
||||||
|
|
||||||
class SilentLine(Line):
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.speaker = -2
|
self.speaker = -2
|
||||||
@@ -196,7 +180,7 @@ class SilentLine(Line):
|
|||||||
class FrontData():
|
class FrontData():
|
||||||
status: str = ''
|
status: str = ''
|
||||||
error: str = ''
|
error: str = ''
|
||||||
lines: list[Line] = field(default_factory=list)
|
lines: list[Segment] = field(default_factory=list)
|
||||||
buffer_transcription: str = ''
|
buffer_transcription: str = ''
|
||||||
buffer_diarization: str = ''
|
buffer_diarization: str = ''
|
||||||
buffer_translation: str = ''
|
buffer_translation: str = ''
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence,
|
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
|
||||||
SilentLine, SpeakerSegment,
|
SilentSegment, SpeakerSegment,
|
||||||
TimedText)
|
TimedText)
|
||||||
|
|
||||||
|
|
||||||
@@ -27,6 +27,14 @@ class TokensAlignment:
|
|||||||
self.sep: str = sep if sep is not None else ' '
|
self.sep: str = sep if sep is not None else ' '
|
||||||
self.beg_loop: Optional[float] = None
|
self.beg_loop: Optional[float] = None
|
||||||
|
|
||||||
|
self.validated_segments: List[Segment] = []
|
||||||
|
self.current_line_tokens: List[ASRToken] = []
|
||||||
|
self.diarization_buffer: List[ASRToken] = []
|
||||||
|
|
||||||
|
self.last_punctuation = None
|
||||||
|
self.last_uncompleted_punc_segment: PuncSegment = None
|
||||||
|
self.unvalidated_tokens: PuncSegment = []
|
||||||
|
|
||||||
def update(self) -> None:
|
def update(self) -> None:
|
||||||
"""Drain state buffers into the running alignment context."""
|
"""Drain state buffers into the running alignment context."""
|
||||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||||
@@ -39,27 +47,30 @@ class TokensAlignment:
|
|||||||
self.all_translation_segments.extend(self.new_translation)
|
self.all_translation_segments.extend(self.new_translation)
|
||||||
self.new_translation_buffer = self.state.new_translation_buffer
|
self.new_translation_buffer = self.state.new_translation_buffer
|
||||||
|
|
||||||
def add_translation(self, line: Line) -> None:
|
def add_translation(self, segment: Segment) -> None:
|
||||||
"""Append translated text segments that overlap with a line."""
|
"""Append translated text segments that overlap with a segment."""
|
||||||
|
if segment.translation is None:
|
||||||
|
segment.translation = ''
|
||||||
for ts in self.all_translation_segments:
|
for ts in self.all_translation_segments:
|
||||||
if ts.is_within(line):
|
if ts.is_within(segment):
|
||||||
line.translation += ts.text + (self.sep if ts.text else '')
|
if ts.text:
|
||||||
elif line.translation:
|
segment.translation += ts.text + self.sep
|
||||||
|
elif segment.translation:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
|
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
|
||||||
"""Group tokens into segments split by punctuation and explicit silence."""
|
"""Group tokens into segments split by punctuation and explicit silence."""
|
||||||
segments = []
|
segments = []
|
||||||
segment_start_idx = 0
|
segment_start_idx = 0
|
||||||
for i, token in enumerate(self.all_tokens):
|
for i, token in enumerate(self.all_tokens):
|
||||||
if token.is_silence():
|
if token.is_silence():
|
||||||
previous_segment = Segment.from_tokens(
|
previous_segment = PuncSegment.from_tokens(
|
||||||
tokens=self.all_tokens[segment_start_idx: i],
|
tokens=self.all_tokens[segment_start_idx: i],
|
||||||
)
|
)
|
||||||
if previous_segment:
|
if previous_segment:
|
||||||
segments.append(previous_segment)
|
segments.append(previous_segment)
|
||||||
segment = Segment.from_tokens(
|
segment = PuncSegment.from_tokens(
|
||||||
tokens=[token],
|
tokens=[token],
|
||||||
is_silence=True
|
is_silence=True
|
||||||
)
|
)
|
||||||
@@ -67,19 +78,47 @@ class TokensAlignment:
|
|||||||
segment_start_idx = i+1
|
segment_start_idx = i+1
|
||||||
else:
|
else:
|
||||||
if token.has_punctuation():
|
if token.has_punctuation():
|
||||||
segment = Segment.from_tokens(
|
segment = PuncSegment.from_tokens(
|
||||||
tokens=self.all_tokens[segment_start_idx: i+1],
|
tokens=self.all_tokens[segment_start_idx: i+1],
|
||||||
)
|
)
|
||||||
segments.append(segment)
|
segments.append(segment)
|
||||||
segment_start_idx = i+1
|
segment_start_idx = i+1
|
||||||
|
|
||||||
final_segment = Segment.from_tokens(
|
final_segment = PuncSegment.from_tokens(
|
||||||
tokens=self.all_tokens[segment_start_idx:],
|
tokens=self.all_tokens[segment_start_idx:],
|
||||||
)
|
)
|
||||||
if final_segment:
|
if final_segment:
|
||||||
segments.append(final_segment)
|
segments.append(final_segment)
|
||||||
return segments
|
return segments
|
||||||
|
|
||||||
|
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
|
||||||
|
new_punc_segments = []
|
||||||
|
segment_start_idx = 0
|
||||||
|
self.unvalidated_tokens += self.new_tokens
|
||||||
|
for i, token in enumerate(self.unvalidated_tokens):
|
||||||
|
if token.is_silence():
|
||||||
|
previous_segment = PuncSegment.from_tokens(
|
||||||
|
tokens=self.unvalidated_tokens[segment_start_idx: i],
|
||||||
|
)
|
||||||
|
if previous_segment:
|
||||||
|
new_punc_segments.append(previous_segment)
|
||||||
|
segment = PuncSegment.from_tokens(
|
||||||
|
tokens=[token],
|
||||||
|
is_silence=True
|
||||||
|
)
|
||||||
|
new_punc_segments.append(segment)
|
||||||
|
segment_start_idx = i+1
|
||||||
|
else:
|
||||||
|
if token.has_punctuation():
|
||||||
|
segment = PuncSegment.from_tokens(
|
||||||
|
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
|
||||||
|
)
|
||||||
|
new_punc_segments.append(segment)
|
||||||
|
segment_start_idx = i+1
|
||||||
|
|
||||||
|
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
|
||||||
|
return new_punc_segments
|
||||||
|
|
||||||
|
|
||||||
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
|
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
|
||||||
"""Merge consecutive diarization slices that share the same speaker."""
|
"""Merge consecutive diarization slices that share the same speaker."""
|
||||||
@@ -102,8 +141,8 @@ class TokensAlignment:
|
|||||||
|
|
||||||
return max(0, end - start)
|
return max(0, end - start)
|
||||||
|
|
||||||
def get_lines_diarization(self) -> Tuple[List[Line], str]:
|
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
|
||||||
"""Build lines when diarization is enabled and track overflow buffer."""
|
"""Build segments when diarization is enabled and track overflow buffer."""
|
||||||
diarization_buffer = ''
|
diarization_buffer = ''
|
||||||
punctuation_segments = self.compute_punctuations_segments()
|
punctuation_segments = self.compute_punctuations_segments()
|
||||||
diarization_segments = self.concatenate_diar_segments()
|
diarization_segments = self.concatenate_diar_segments()
|
||||||
@@ -121,18 +160,18 @@ class TokensAlignment:
|
|||||||
max_overlap_speaker = diarization_segment.speaker + 1
|
max_overlap_speaker = diarization_segment.speaker + 1
|
||||||
punctuation_segment.speaker = max_overlap_speaker
|
punctuation_segment.speaker = max_overlap_speaker
|
||||||
|
|
||||||
lines = []
|
segments = []
|
||||||
if punctuation_segments:
|
if punctuation_segments:
|
||||||
lines = [Line().build_from_segment(punctuation_segments[0])]
|
segments = [punctuation_segments[0]]
|
||||||
for segment in punctuation_segments[1:]:
|
for segment in punctuation_segments[1:]:
|
||||||
if segment.speaker == lines[-1].speaker:
|
if segment.speaker == segments[-1].speaker:
|
||||||
if lines[-1].text:
|
if segments[-1].text:
|
||||||
lines[-1].text += segment.text
|
segments[-1].text += segment.text
|
||||||
lines[-1].end = segment.end
|
segments[-1].end = segment.end
|
||||||
else:
|
else:
|
||||||
lines.append(Line().build_from_segment(segment))
|
segments.append(segment)
|
||||||
|
|
||||||
return lines, diarization_buffer
|
return segments, diarization_buffer
|
||||||
|
|
||||||
|
|
||||||
def get_lines(
|
def get_lines(
|
||||||
@@ -140,40 +179,42 @@ class TokensAlignment:
|
|||||||
diarization: bool = False,
|
diarization: bool = False,
|
||||||
translation: bool = False,
|
translation: bool = False,
|
||||||
current_silence: Optional[Silence] = None
|
current_silence: Optional[Silence] = None
|
||||||
) -> Tuple[List[Line], str, Union[str, TimedText]]:
|
) -> Tuple[List[Segment], str, Union[str, TimedText]]:
|
||||||
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
|
"""Return the formatted segments plus buffers, optionally with diarization/translation."""
|
||||||
if diarization:
|
if diarization:
|
||||||
lines, diarization_buffer = self.get_lines_diarization()
|
segments, diarization_buffer = self.get_lines_diarization()
|
||||||
else:
|
else:
|
||||||
diarization_buffer = ''
|
diarization_buffer = ''
|
||||||
lines = []
|
for token in self.new_tokens:
|
||||||
current_line_tokens = []
|
if isinstance(token, Silence):
|
||||||
for token in self.all_tokens:
|
if self.current_line_tokens:
|
||||||
if token.is_silence():
|
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||||
if current_line_tokens:
|
self.current_line_tokens = []
|
||||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
|
||||||
current_line_tokens = []
|
|
||||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||||
if lines and lines[-1].is_silent():
|
if self.validated_segments and self.validated_segments[-1].is_silence():
|
||||||
lines[-1].end = end_silence
|
self.validated_segments[-1].end = end_silence
|
||||||
else:
|
else:
|
||||||
lines.append(SilentLine(
|
self.validated_segments.append(SilentSegment(
|
||||||
start = token.start,
|
start=token.start,
|
||||||
end = end_silence
|
end=end_silence
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
current_line_tokens.append(token)
|
self.current_line_tokens.append(token)
|
||||||
if current_line_tokens:
|
|
||||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
segments = list(self.validated_segments)
|
||||||
|
if self.current_line_tokens:
|
||||||
|
segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||||
|
|
||||||
if current_silence:
|
if current_silence:
|
||||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||||
if lines and lines[-1].is_silent():
|
if segments and segments[-1].is_silence():
|
||||||
lines[-1].end = end_silence
|
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
|
||||||
else:
|
else:
|
||||||
lines.append(SilentLine(
|
segments.append(SilentSegment(
|
||||||
start = current_silence.start,
|
start=current_silence.start,
|
||||||
end = end_silence
|
end=end_silence
|
||||||
))
|
))
|
||||||
if translation:
|
if translation:
|
||||||
[self.add_translation(line) for line in lines if not type(line) == Silence]
|
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
|
||||||
return lines, diarization_buffer, self.new_translation_buffer.text
|
return segments, diarization_buffer, self.new_translation_buffer.text
|
||||||
|
|||||||
395
whisperlivekit/voxtral_hf_streaming.py
Normal file
395
whisperlivekit/voxtral_hf_streaming.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""
|
||||||
|
Voxtral Mini Realtime streaming backend using HuggingFace Transformers.
|
||||||
|
|
||||||
|
Uses VoxtralRealtimeForConditionalGeneration with a background generate thread
|
||||||
|
and queue-based audio feeding for real-time streaming transcription.
|
||||||
|
Supports CUDA, CPU, and MPS devices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralHFStreamingASR:
|
||||||
|
"""Voxtral model holder using HuggingFace Transformers."""
|
||||||
|
|
||||||
|
sep = " "
|
||||||
|
|
||||||
|
def __init__(self, logfile=sys.stderr, **kwargs):
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
VoxtralRealtimeForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logfile = logfile
|
||||||
|
self.transcribe_kargs = {}
|
||||||
|
|
||||||
|
lan = kwargs.get("lan", "auto")
|
||||||
|
self.original_language = None if lan == "auto" else lan
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"
|
||||||
|
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
|
||||||
|
if not model_path:
|
||||||
|
model_size = kwargs.get("model_size", "")
|
||||||
|
if model_size and ("/" in model_size or model_size.startswith(".")):
|
||||||
|
model_path = model_size
|
||||||
|
else:
|
||||||
|
model_path = DEFAULT_MODEL
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
logger.info(f"Loading Voxtral model '{model_path}' via HF Transformers...")
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||||
|
self.model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
logger.info(f"Voxtral HF model loaded in {time.time() - t:.2f}s on {self.model.device}")
|
||||||
|
|
||||||
|
self.backend_choice = "voxtral"
|
||||||
|
self.tokenizer = None # sentence tokenizer — not needed for streaming
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralHFStreamingOnlineProcessor:
|
||||||
|
"""
|
||||||
|
Online processor for Voxtral streaming ASR via HuggingFace Transformers.
|
||||||
|
|
||||||
|
Uses a background thread running model.generate() with a queue-based
|
||||||
|
input_features_generator and TextIteratorStreamer for real-time output.
|
||||||
|
Each decoded token corresponds to ~80ms of audio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
|
def __init__(self, asr: VoxtralHFStreamingASR, logfile=sys.stderr):
|
||||||
|
self.asr = asr
|
||||||
|
self.logfile = logfile
|
||||||
|
self.end = 0.0
|
||||||
|
self.buffer = []
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
processor = asr.processor
|
||||||
|
self._first_chunk_samples = processor.num_samples_first_audio_chunk
|
||||||
|
self._chunk_samples = processor.num_samples_per_audio_chunk
|
||||||
|
self._chunk_step = processor.raw_audio_length_per_tok
|
||||||
|
n_right_pad = processor.num_right_pad_tokens
|
||||||
|
if callable(n_right_pad):
|
||||||
|
n_right_pad = n_right_pad()
|
||||||
|
self._right_pad_samples = int(n_right_pad * processor.raw_audio_length_per_tok)
|
||||||
|
self._seconds_per_token = processor.raw_audio_length_per_tok / self.SAMPLING_RATE
|
||||||
|
|
||||||
|
self._reset_state()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[voxtral-hf] Initialized. first_chunk={self._first_chunk_samples} samples, "
|
||||||
|
f"chunk={self._chunk_samples}, step={self._chunk_step}, "
|
||||||
|
f"right_pad={self._right_pad_samples}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _reset_state(self):
|
||||||
|
self._pending_audio = np.zeros(0, dtype=np.float32)
|
||||||
|
self._audio_queue: queue.Queue = queue.Queue()
|
||||||
|
self._streamer_texts: List[str] = []
|
||||||
|
self._generate_thread: Optional[threading.Thread] = None
|
||||||
|
self._generate_started = False
|
||||||
|
self._generate_finished = False
|
||||||
|
self._generate_error: Optional[Exception] = None
|
||||||
|
|
||||||
|
# Text accumulation and word extraction
|
||||||
|
self._accumulated_text = ""
|
||||||
|
self._n_text_tokens_received = 0
|
||||||
|
self._n_committed_words = 0
|
||||||
|
self._global_time_offset = 0.0
|
||||||
|
|
||||||
|
# Lock for text state accessed from both generate thread and main thread
|
||||||
|
self._text_lock = threading.Lock()
|
||||||
|
|
||||||
|
# ── Interface methods ──
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||||
|
self.end = audio_stream_end_time
|
||||||
|
self._pending_audio = np.append(self._pending_audio, audio)
|
||||||
|
self.audio_buffer = self._pending_audio
|
||||||
|
|
||||||
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
|
try:
|
||||||
|
return self._process_iter_inner(is_last)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[voxtral-hf] process_iter exception: {e}", exc_info=True)
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def get_buffer(self) -> Transcript:
|
||||||
|
"""Return all uncommitted text as buffer."""
|
||||||
|
with self._text_lock:
|
||||||
|
text = self._accumulated_text
|
||||||
|
if not text:
|
||||||
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
uncommitted = words[self._n_committed_words:]
|
||||||
|
if uncommitted:
|
||||||
|
return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
|
||||||
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
|
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""Flush all uncommitted words when silence starts."""
|
||||||
|
self._drain_streamer()
|
||||||
|
words = self._flush_all_pending_words()
|
||||||
|
logger.info(f"[voxtral-hf] start_silence: flushed {len(words)} words")
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
def end_silence(self, silence_duration: float, offset: float):
|
||||||
|
self._global_time_offset += silence_duration
|
||||||
|
self.end += silence_duration
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker):
|
||||||
|
self.start_silence()
|
||||||
|
|
||||||
|
def warmup(self, audio, init_prompt=""):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""Flush remaining audio with right-padding and stop the generate thread."""
|
||||||
|
# Add right-padding so the model can finish decoding
|
||||||
|
if self._right_pad_samples > 0:
|
||||||
|
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
|
||||||
|
self._pending_audio = np.append(self._pending_audio, right_pad)
|
||||||
|
|
||||||
|
# Feed remaining audio
|
||||||
|
if self._generate_started and not self._generate_finished:
|
||||||
|
self._feed_pending_audio()
|
||||||
|
# Signal end of audio
|
||||||
|
self._audio_queue.put(None)
|
||||||
|
# Wait for generate to finish
|
||||||
|
if self._generate_thread is not None:
|
||||||
|
self._generate_thread.join(timeout=30.0)
|
||||||
|
elif not self._generate_started and len(self._pending_audio) >= self._first_chunk_samples:
|
||||||
|
# Never started but have enough audio — start and immediately finish
|
||||||
|
self._start_generate_thread()
|
||||||
|
self._feed_pending_audio()
|
||||||
|
self._audio_queue.put(None)
|
||||||
|
if self._generate_thread is not None:
|
||||||
|
self._generate_thread.join(timeout=30.0)
|
||||||
|
|
||||||
|
self._drain_streamer()
|
||||||
|
words = self._flush_all_pending_words()
|
||||||
|
logger.info(f"[voxtral-hf] finish: flushed {len(words)} words")
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
# ── Generate thread management ──
|
||||||
|
|
||||||
|
def _start_generate_thread(self):
|
||||||
|
"""Start model.generate() in a background thread with streaming."""
|
||||||
|
import torch
|
||||||
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
|
processor = self.asr.processor
|
||||||
|
model = self.asr.model
|
||||||
|
|
||||||
|
# Extract first chunk
|
||||||
|
first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
|
||||||
|
self._pending_audio = self._pending_audio[self._first_chunk_samples:]
|
||||||
|
|
||||||
|
first_inputs = processor(
|
||||||
|
first_chunk_audio,
|
||||||
|
is_streaming=True,
|
||||||
|
is_first_audio_chunk=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
first_inputs = first_inputs.to(model.device, dtype=model.dtype)
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
processor.tokenizer,
|
||||||
|
skip_prompt=True,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
self._streamer = streamer
|
||||||
|
|
||||||
|
audio_queue = self._audio_queue
|
||||||
|
|
||||||
|
def input_features_gen():
|
||||||
|
yield first_inputs.input_features
|
||||||
|
while True:
|
||||||
|
chunk_audio = audio_queue.get()
|
||||||
|
if chunk_audio is None:
|
||||||
|
break
|
||||||
|
inputs = processor(
|
||||||
|
chunk_audio,
|
||||||
|
is_streaming=True,
|
||||||
|
is_first_audio_chunk=False,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
inputs = inputs.to(model.device, dtype=model.dtype)
|
||||||
|
yield inputs.input_features
|
||||||
|
|
||||||
|
def run_generate():
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
# Pass generator as input_features — the model detects GeneratorType
|
||||||
|
# and internally converts it to input_features_generator
|
||||||
|
generate_kwargs = {
|
||||||
|
k: v for k, v in first_inputs.items()
|
||||||
|
if k != "input_features"
|
||||||
|
}
|
||||||
|
model.generate(
|
||||||
|
input_features=input_features_gen(),
|
||||||
|
streamer=streamer,
|
||||||
|
**generate_kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[voxtral-hf] generate error: {e}", exc_info=True)
|
||||||
|
self._generate_error = e
|
||||||
|
finally:
|
||||||
|
self._generate_finished = True
|
||||||
|
|
||||||
|
self._generate_thread = threading.Thread(target=run_generate, daemon=True)
|
||||||
|
self._generate_thread.start()
|
||||||
|
self._generate_started = True
|
||||||
|
logger.info("[voxtral-hf] generate thread started")
|
||||||
|
|
||||||
|
def _feed_pending_audio(self):
|
||||||
|
"""Convert pending audio into properly-sized chunks for the generator."""
|
||||||
|
chunk_size = self._chunk_samples
|
||||||
|
step_size = self._chunk_step
|
||||||
|
|
||||||
|
while len(self._pending_audio) >= chunk_size:
|
||||||
|
chunk = self._pending_audio[:chunk_size]
|
||||||
|
self._audio_queue.put(chunk)
|
||||||
|
self._pending_audio = self._pending_audio[step_size:]
|
||||||
|
|
||||||
|
self.audio_buffer = self._pending_audio
|
||||||
|
|
||||||
|
def _drain_streamer(self):
|
||||||
|
"""Non-blocking drain of all available text from the streamer."""
|
||||||
|
if not self._generate_started:
|
||||||
|
return
|
||||||
|
|
||||||
|
text_queue = self._streamer.text_queue
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
text_fragment = text_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
# TextIteratorStreamer uses None as end-of-stream sentinel
|
||||||
|
if text_fragment is None:
|
||||||
|
self._generate_finished = True
|
||||||
|
break
|
||||||
|
if text_fragment:
|
||||||
|
with self._text_lock:
|
||||||
|
self._accumulated_text += text_fragment
|
||||||
|
self._n_text_tokens_received += 1
|
||||||
|
|
||||||
|
# ── Word extraction ──
|
||||||
|
|
||||||
|
def _pos_to_time(self, token_position: int) -> float:
|
||||||
|
"""Convert token position to seconds."""
|
||||||
|
return token_position * self._seconds_per_token + self._global_time_offset
|
||||||
|
|
||||||
|
def _extract_new_words(self) -> List[ASRToken]:
|
||||||
|
"""Extract complete words (all but the last, which may still be growing)."""
|
||||||
|
with self._text_lock:
|
||||||
|
text = self._accumulated_text
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
new_words: List[ASRToken] = []
|
||||||
|
n_tokens = self._n_text_tokens_received
|
||||||
|
n_words_total = len(words)
|
||||||
|
|
||||||
|
while len(words) > self._n_committed_words + 1:
|
||||||
|
word = words[self._n_committed_words]
|
||||||
|
word_idx = self._n_committed_words
|
||||||
|
|
||||||
|
tok_start = int(word_idx / n_words_total * n_tokens) if n_words_total > 0 else 0
|
||||||
|
tok_end = int((word_idx + 1) / n_words_total * n_tokens) if n_words_total > 0 else 0
|
||||||
|
|
||||||
|
start_time = self._pos_to_time(tok_start)
|
||||||
|
end_time = self._pos_to_time(tok_end)
|
||||||
|
|
||||||
|
text_out = word if self._n_committed_words == 0 else " " + word
|
||||||
|
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
||||||
|
self._n_committed_words += 1
|
||||||
|
|
||||||
|
return new_words
|
||||||
|
|
||||||
|
def _flush_all_pending_words(self) -> List[ASRToken]:
|
||||||
|
"""Flush ALL words including the last partial one."""
|
||||||
|
with self._text_lock:
|
||||||
|
text = self._accumulated_text
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
new_words: List[ASRToken] = []
|
||||||
|
n_tokens = max(self._n_text_tokens_received, 1)
|
||||||
|
n_words_total = max(len(words), 1)
|
||||||
|
|
||||||
|
while self._n_committed_words < len(words):
|
||||||
|
word = words[self._n_committed_words]
|
||||||
|
word_idx = self._n_committed_words
|
||||||
|
|
||||||
|
tok_start = int(word_idx / n_words_total * n_tokens)
|
||||||
|
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
|
||||||
|
|
||||||
|
start_time = self._pos_to_time(tok_start)
|
||||||
|
end_time = self._pos_to_time(tok_end)
|
||||||
|
|
||||||
|
text_out = word if self._n_committed_words == 0 else " " + word
|
||||||
|
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
|
||||||
|
self._n_committed_words += 1
|
||||||
|
|
||||||
|
return new_words
|
||||||
|
|
||||||
|
# ── Core processing ──
|
||||||
|
|
||||||
|
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
|
||||||
|
# Start generate thread when enough audio is buffered
|
||||||
|
if not self._generate_started:
|
||||||
|
if len(self._pending_audio) >= self._first_chunk_samples:
|
||||||
|
self._start_generate_thread()
|
||||||
|
self._feed_pending_audio()
|
||||||
|
else:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
# Feed any new pending audio
|
||||||
|
if self._generate_started and not self._generate_finished:
|
||||||
|
self._feed_pending_audio()
|
||||||
|
|
||||||
|
# If generate finished unexpectedly (EOS) but new audio arrived, restart
|
||||||
|
if self._generate_finished and len(self._pending_audio) >= self._first_chunk_samples:
|
||||||
|
self._drain_streamer()
|
||||||
|
flush_words = self._flush_all_pending_words()
|
||||||
|
# Reset for new utterance
|
||||||
|
old_offset = self._global_time_offset
|
||||||
|
self._reset_state()
|
||||||
|
self._global_time_offset = old_offset
|
||||||
|
self._start_generate_thread()
|
||||||
|
self._feed_pending_audio()
|
||||||
|
return flush_words, self.end
|
||||||
|
|
||||||
|
# Drain available text from streamer
|
||||||
|
self._drain_streamer()
|
||||||
|
|
||||||
|
# Extract complete words
|
||||||
|
new_words = self._extract_new_words()
|
||||||
|
|
||||||
|
if new_words:
|
||||||
|
logger.info(f"[voxtral-hf] returning {len(new_words)} words: {[w.text for w in new_words]}")
|
||||||
|
|
||||||
|
self.buffer = []
|
||||||
|
return new_words, self.end
|
||||||
6
whisperlivekit/voxtral_mlx/__init__.py
Normal file
6
whisperlivekit/voxtral_mlx/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Pure-MLX Voxtral Realtime backend for WhisperLiveKit."""
|
||||||
|
|
||||||
|
from .loader import load_voxtral_model
|
||||||
|
from .model import VoxtralMLXModel
|
||||||
|
|
||||||
|
__all__ = ["load_voxtral_model", "VoxtralMLXModel"]
|
||||||
282
whisperlivekit/voxtral_mlx/loader.py
Normal file
282
whisperlivekit/voxtral_mlx/loader.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""
|
||||||
|
Model weight loading for the MLX Voxtral Realtime backend.
|
||||||
|
|
||||||
|
Supports two on-disk formats:
|
||||||
|
1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load,
|
||||||
|
with optional quantisation metadata.
|
||||||
|
2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``):
|
||||||
|
requires weight renaming and conv-weight transposition.
|
||||||
|
|
||||||
|
The public entry point is :func:`load_voxtral_model` which returns the
|
||||||
|
model, tokenizer, and raw config dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from .model import VoxtralMLXModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Downloading
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_ALLOWED_PATTERNS = [
|
||||||
|
"consolidated.safetensors",
|
||||||
|
"model*.safetensors",
|
||||||
|
"model.safetensors.index.json",
|
||||||
|
"params.json",
|
||||||
|
"config.json",
|
||||||
|
"tekken.json",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path:
|
||||||
|
"""Download model files from HuggingFace Hub and return the local path."""
|
||||||
|
return Path(snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Weight name remapping (Mistral → our naming)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_NAME_RULES: list[tuple[str, str]] = [
|
||||||
|
# Encoder convolutions
|
||||||
|
(r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"),
|
||||||
|
(r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"),
|
||||||
|
# Encoder transformer blocks
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)",
|
||||||
|
r"encoder.blocks.\1.self_attn.q_proj.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)",
|
||||||
|
r"encoder.blocks.\1.self_attn.k_proj.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)",
|
||||||
|
r"encoder.blocks.\1.self_attn.v_proj.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)",
|
||||||
|
r"encoder.blocks.\1.self_attn.out_proj.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)",
|
||||||
|
r"encoder.blocks.\1.pre_attn_norm.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)",
|
||||||
|
r"encoder.blocks.\1.ffn.gate.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)",
|
||||||
|
r"encoder.blocks.\1.ffn.down.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)",
|
||||||
|
r"encoder.blocks.\1.ffn.up.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)",
|
||||||
|
r"encoder.blocks.\1.pre_ffn_norm.\2"),
|
||||||
|
(r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"),
|
||||||
|
# Adapter
|
||||||
|
(r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"),
|
||||||
|
(r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"),
|
||||||
|
# Decoder embedding
|
||||||
|
(r"tok_embeddings\.weight", r"decoder.token_embedding.weight"),
|
||||||
|
# Decoder blocks
|
||||||
|
(r"layers\.(\d+)\.attention\.wq\.weight",
|
||||||
|
r"decoder.blocks.\1.self_attn.q_proj.weight"),
|
||||||
|
(r"layers\.(\d+)\.attention\.wk\.weight",
|
||||||
|
r"decoder.blocks.\1.self_attn.k_proj.weight"),
|
||||||
|
(r"layers\.(\d+)\.attention\.wv\.weight",
|
||||||
|
r"decoder.blocks.\1.self_attn.v_proj.weight"),
|
||||||
|
(r"layers\.(\d+)\.attention\.wo\.weight",
|
||||||
|
r"decoder.blocks.\1.self_attn.out_proj.weight"),
|
||||||
|
(r"layers\.(\d+)\.attention_norm\.weight",
|
||||||
|
r"decoder.blocks.\1.pre_attn_norm.weight"),
|
||||||
|
(r"layers\.(\d+)\.feed_forward\.w1\.weight",
|
||||||
|
r"decoder.blocks.\1.ffn.gate.weight"),
|
||||||
|
(r"layers\.(\d+)\.feed_forward\.w2\.weight",
|
||||||
|
r"decoder.blocks.\1.ffn.down.weight"),
|
||||||
|
(r"layers\.(\d+)\.feed_forward\.w3\.weight",
|
||||||
|
r"decoder.blocks.\1.ffn.up.weight"),
|
||||||
|
(r"layers\.(\d+)\.ffn_norm\.weight",
|
||||||
|
r"decoder.blocks.\1.pre_ffn_norm.weight"),
|
||||||
|
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight",
|
||||||
|
r"decoder.blocks.\1.adaptive_scale.proj_in.weight"),
|
||||||
|
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight",
|
||||||
|
r"decoder.blocks.\1.adaptive_scale.proj_out.weight"),
|
||||||
|
# Decoder final norm
|
||||||
|
(r"norm\.weight", r"decoder.final_norm.weight"),
|
||||||
|
]
|
||||||
|
|
||||||
|
_PREFIX_STRIP = re.compile(
|
||||||
|
r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _translate_weight_name(name: str) -> str | None:
|
||||||
|
name = _PREFIX_STRIP.sub("", name)
|
||||||
|
for pattern, replacement in _NAME_RULES:
|
||||||
|
result, n = re.subn(f"^{pattern}$", replacement, name)
|
||||||
|
if n:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_conv_weight(name: str) -> bool:
|
||||||
|
return ("conv1.weight" in name or "conv2.weight" in name) and "bias" not in name
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Converted-format weight remapping (voxmlx names → our names)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_CONVERTED_RULES: list[tuple[str, str]] = [
|
||||||
|
# Adapter
|
||||||
|
(r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"),
|
||||||
|
(r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"),
|
||||||
|
# Encoder transformer blocks
|
||||||
|
(r"encoder\.layers\.(\d+)\.attention\.(.*)", r"encoder.blocks.\1.self_attn.\2"),
|
||||||
|
(r"encoder\.layers\.(\d+)\.attn_norm\.(.*)", r"encoder.blocks.\1.pre_attn_norm.\2"),
|
||||||
|
(r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"encoder.blocks.\1.ffn.gate.\2"),
|
||||||
|
(r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"encoder.blocks.\1.ffn.down.\2"),
|
||||||
|
(r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"encoder.blocks.\1.ffn.up.\2"),
|
||||||
|
(r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)", r"encoder.blocks.\1.pre_ffn_norm.\2"),
|
||||||
|
(r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"),
|
||||||
|
# Decoder embedding
|
||||||
|
(r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"),
|
||||||
|
# Decoder blocks
|
||||||
|
(r"language_model\.layers\.(\d+)\.attention\.(.*)", r"decoder.blocks.\1.self_attn.\2"),
|
||||||
|
(r"language_model\.layers\.(\d+)\.attn_norm\.(.*)", r"decoder.blocks.\1.pre_attn_norm.\2"),
|
||||||
|
(r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"decoder.blocks.\1.ffn.gate.\2"),
|
||||||
|
(r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"decoder.blocks.\1.ffn.down.\2"),
|
||||||
|
(r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"decoder.blocks.\1.ffn.up.\2"),
|
||||||
|
(r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)", r"decoder.blocks.\1.pre_ffn_norm.\2"),
|
||||||
|
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)",
|
||||||
|
r"decoder.blocks.\1.adaptive_scale.proj_in.\2"),
|
||||||
|
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)",
|
||||||
|
r"decoder.blocks.\1.adaptive_scale.proj_out.\2"),
|
||||||
|
(r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Also remap o_proj → out_proj in both encoder and decoder
|
||||||
|
_POST_RENAME = [
|
||||||
|
(r"\.o_proj\.", r".out_proj."),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _remap_converted_name(name: str) -> str:
|
||||||
|
"""Translate a converted-format weight name to our naming convention."""
|
||||||
|
for pattern, replacement in _CONVERTED_RULES:
|
||||||
|
result, n = re.subn(f"^{pattern}$", replacement, name)
|
||||||
|
if n:
|
||||||
|
name = result
|
||||||
|
break
|
||||||
|
for pattern, replacement in _POST_RENAME:
|
||||||
|
name = re.sub(pattern, replacement, name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Loading strategies
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _has_converted_layout(path: Path) -> bool:
|
||||||
|
return (path / "config.json").exists() and not (path / "consolidated.safetensors").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def _load_converted_weights(path: Path):
|
||||||
|
with open(path / "config.json") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
model = VoxtralMLXModel(config)
|
||||||
|
|
||||||
|
quant = config.get("quantization")
|
||||||
|
if quant is not None:
|
||||||
|
gs = quant["group_size"]
|
||||||
|
nn.quantize(
|
||||||
|
model,
|
||||||
|
group_size=gs,
|
||||||
|
bits=quant["bits"],
|
||||||
|
class_predicate=lambda _p, m: (
|
||||||
|
hasattr(m, "to_quantized") and m.weight.shape[-1] % gs == 0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
index_file = path / "model.safetensors.index.json"
|
||||||
|
if index_file.exists():
|
||||||
|
with open(index_file) as f:
|
||||||
|
shard_map = json.load(f)
|
||||||
|
shard_files = sorted(set(shard_map["weight_map"].values()))
|
||||||
|
weights = {}
|
||||||
|
for sf in shard_files:
|
||||||
|
weights.update(mx.load(str(path / sf)))
|
||||||
|
else:
|
||||||
|
weights = mx.load(str(path / "model.safetensors"))
|
||||||
|
|
||||||
|
remapped = {_remap_converted_name(k): v for k, v in weights.items()}
|
||||||
|
model.load_weights(list(remapped.items()))
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
return model, config
|
||||||
|
|
||||||
|
|
||||||
|
def _load_original_weights(path: Path):
|
||||||
|
with open(path / "params.json") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
model = VoxtralMLXModel(config)
|
||||||
|
|
||||||
|
raw = mx.load(str(path / "consolidated.safetensors"))
|
||||||
|
mapped: dict[str, mx.array] = {}
|
||||||
|
skipped: list[str] = []
|
||||||
|
|
||||||
|
for name, tensor in raw.items():
|
||||||
|
if name == "output.weight":
|
||||||
|
continue
|
||||||
|
new_name = _translate_weight_name(name)
|
||||||
|
if new_name is None:
|
||||||
|
skipped.append(name)
|
||||||
|
continue
|
||||||
|
# Conv weights: PyTorch [C_out, C_in, K] → MLX [C_out, K, C_in]
|
||||||
|
if _is_conv_weight(new_name):
|
||||||
|
tensor = mx.swapaxes(tensor, 1, 2)
|
||||||
|
mapped[new_name] = tensor
|
||||||
|
|
||||||
|
if skipped:
|
||||||
|
logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5])
|
||||||
|
|
||||||
|
model.load_weights(list(mapped.items()))
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
return model, config
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tokenizer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _load_tokenizer(model_dir: Path):
|
||||||
|
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||||
|
return Tekkenizer.from_file(str(model_dir / "tekken.json"))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID):
|
||||||
|
"""Load a Voxtral Realtime model and its tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_or_id: Local directory path **or** a HuggingFace model ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``(model, tokenizer, config)``
|
||||||
|
"""
|
||||||
|
p = Path(path_or_id)
|
||||||
|
if not p.exists():
|
||||||
|
p = download_weights(path_or_id)
|
||||||
|
|
||||||
|
if _has_converted_layout(p):
|
||||||
|
model, config = _load_converted_weights(p)
|
||||||
|
else:
|
||||||
|
model, config = _load_original_weights(p)
|
||||||
|
|
||||||
|
tokenizer = _load_tokenizer(p)
|
||||||
|
logger.info("Voxtral MLX model loaded from %s", p)
|
||||||
|
return model, tokenizer, config
|
||||||
534
whisperlivekit/voxtral_mlx/model.py
Normal file
534
whisperlivekit/voxtral_mlx/model.py
Normal file
@@ -0,0 +1,534 @@
|
|||||||
|
"""
|
||||||
|
Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits
|
||||||
|
with DelayEmbedding providing time-conditioning to the decoder.
|
||||||
|
|
||||||
|
The model supports both batch inference (full audio) and incremental streaming
|
||||||
|
(one chunk at a time with cached encoder/decoder state).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KV Cache
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SlidingKVCache:
|
||||||
|
"""Bounded key-value cache with rotating buffer for sliding-window attention.
|
||||||
|
|
||||||
|
Uses in-place writes for single-token autoregressive steps and
|
||||||
|
concatenation for multi-token prefills. Pre-allocates in blocks of
|
||||||
|
``alloc_step`` entries to reduce repeated allocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
alloc_step = 256
|
||||||
|
|
||||||
|
def __init__(self, capacity: int):
|
||||||
|
self.capacity = capacity
|
||||||
|
self.keys = None
|
||||||
|
self.values = None
|
||||||
|
self._offset = 0
|
||||||
|
self._write_idx = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def offset(self) -> int:
|
||||||
|
return self._offset
|
||||||
|
|
||||||
|
# -- helpers --
|
||||||
|
|
||||||
|
def _reorder(self, buf):
|
||||||
|
"""Return *buf* in temporal order (unwrap the circular buffer)."""
|
||||||
|
if self._write_idx == buf.shape[2]:
|
||||||
|
return buf
|
||||||
|
if self._write_idx < self._offset:
|
||||||
|
return mx.concatenate(
|
||||||
|
[buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
|
||||||
|
axis=2,
|
||||||
|
)
|
||||||
|
return buf[..., : self._write_idx, :]
|
||||||
|
|
||||||
|
def _drop_oldest(self, buf, n_drop, tail=None):
|
||||||
|
parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
|
||||||
|
if tail is not None:
|
||||||
|
parts.append(tail)
|
||||||
|
return mx.concatenate(parts, axis=2)
|
||||||
|
|
||||||
|
# -- update strategies --
|
||||||
|
|
||||||
|
def _append_concat(self, k, v):
|
||||||
|
"""Multi-token update via concatenation (used during prefill)."""
|
||||||
|
if self.keys is None:
|
||||||
|
self.keys, self.values = k, v
|
||||||
|
else:
|
||||||
|
self.keys = self._reorder(self.keys)
|
||||||
|
self.values = self._reorder(self.values)
|
||||||
|
self._write_idx = self.keys.shape[2]
|
||||||
|
overflow = self._write_idx - self.capacity + 1
|
||||||
|
self.keys = self._drop_oldest(self.keys, overflow, k)
|
||||||
|
self.values = self._drop_oldest(self.values, overflow, v)
|
||||||
|
self._offset += k.shape[2]
|
||||||
|
self._write_idx = self.keys.shape[2]
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
def _write_inplace(self, k, v):
|
||||||
|
"""Single-token update via in-place write (autoregressive step)."""
|
||||||
|
B, n_heads, S, dim_k = k.shape
|
||||||
|
dim_v = v.shape[3]
|
||||||
|
prev = self._offset
|
||||||
|
|
||||||
|
if self.keys is None or (
|
||||||
|
prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
|
||||||
|
):
|
||||||
|
n_new = min(self.alloc_step, self.capacity - prev)
|
||||||
|
fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
|
||||||
|
fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
|
||||||
|
if self.keys is not None:
|
||||||
|
self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
|
||||||
|
self.values = mx.concatenate([self.values, fresh_v], axis=2)
|
||||||
|
else:
|
||||||
|
self.keys, self.values = fresh_k, fresh_v
|
||||||
|
self._write_idx = prev
|
||||||
|
|
||||||
|
overflow = self.keys.shape[2] - self.capacity
|
||||||
|
if overflow > 0:
|
||||||
|
self.keys = self._drop_oldest(self.keys, overflow)
|
||||||
|
self.values = self._drop_oldest(self.values, overflow)
|
||||||
|
self._write_idx = self.capacity
|
||||||
|
|
||||||
|
if self._write_idx == self.capacity:
|
||||||
|
self._write_idx = 0
|
||||||
|
|
||||||
|
self.keys[..., self._write_idx : self._write_idx + S, :] = k
|
||||||
|
self.values[..., self._write_idx : self._write_idx + S, :] = v
|
||||||
|
self._offset += S
|
||||||
|
self._write_idx += S
|
||||||
|
|
||||||
|
if self._offset < self.capacity:
|
||||||
|
return (
|
||||||
|
self.keys[..., : self._offset, :],
|
||||||
|
self.values[..., : self._offset, :],
|
||||||
|
)
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
# -- public API --
|
||||||
|
|
||||||
|
def update_and_fetch(self, k, v):
|
||||||
|
if k.shape[2] == 1:
|
||||||
|
return self._write_inplace(k, v)
|
||||||
|
return self._append_concat(k, v)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Encoder components
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv(nn.Module):
|
||||||
|
"""1-D causal convolution (left-padded so no future leakage)."""
|
||||||
|
|
||||||
|
def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.stride = stride
|
||||||
|
self.kernel = kernel
|
||||||
|
self.left_pad = kernel - stride
|
||||||
|
self.weight = mx.zeros((channels_out, kernel, channels_in))
|
||||||
|
self.bias = mx.zeros((channels_out,))
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
if self.left_pad > 0:
|
||||||
|
x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
|
||||||
|
return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
|
||||||
|
|
||||||
|
|
||||||
|
class _EncoderSelfAttention(nn.Module):
|
||||||
|
def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||||
|
self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||||
|
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
def __call__(self, x, mask, cache=None):
|
||||||
|
B, L, _ = x.shape
|
||||||
|
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
pos = cache.offset if cache is not None else 0
|
||||||
|
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
|
||||||
|
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
k, v = cache.update_and_fetch(k, v)
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||||
|
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
|
||||||
|
|
||||||
|
|
||||||
|
class _EncoderFFN(nn.Module):
|
||||||
|
"""SwiGLU feed-forward for encoder layers."""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, hidden: int):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Linear(dim, hidden, bias=False)
|
||||||
|
self.up = nn.Linear(dim, hidden, bias=False)
|
||||||
|
self.down = nn.Linear(hidden, dim, bias=True)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.down(nn.silu(self.gate(x)) * self.up(x))
|
||||||
|
|
||||||
|
|
||||||
|
class _EncoderBlock(nn.Module):
|
||||||
|
def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
|
||||||
|
self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
|
||||||
|
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
|
||||||
|
self.ffn = _EncoderFFN(dim, hidden)
|
||||||
|
|
||||||
|
def __call__(self, x, mask, cache=None):
|
||||||
|
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
|
||||||
|
x = x + self.ffn(self.pre_ffn_norm(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingEncoder(nn.Module):
|
||||||
|
"""Causal Whisper-style encoder with two causal convolutions followed by
|
||||||
|
a stack of transformer blocks. Supports both full-sequence and
|
||||||
|
incremental (streaming) forward passes."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mel_channels: int = 128,
|
||||||
|
dim: int = 1280,
|
||||||
|
n_layers: int = 32,
|
||||||
|
n_heads: int = 32,
|
||||||
|
head_dim: int = 64,
|
||||||
|
hidden_dim: int = 5120,
|
||||||
|
rope_theta: float = 1e6,
|
||||||
|
sliding_window: int = 750,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
|
||||||
|
self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
|
||||||
|
self.blocks = [
|
||||||
|
_EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
|
||||||
|
for _ in range(n_layers)
|
||||||
|
]
|
||||||
|
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
# -- full-sequence --
|
||||||
|
|
||||||
|
def _apply_convs(self, mel: mx.array) -> mx.array:
|
||||||
|
x = mel.T[None, :, :] # [1, T, mel_channels]
|
||||||
|
x = nn.gelu(self.conv1(x))
|
||||||
|
x = nn.gelu(self.conv2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, mel: mx.array) -> mx.array:
|
||||||
|
x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
|
||||||
|
for blk in self.blocks:
|
||||||
|
x = blk(x, mask="causal")
|
||||||
|
return self.final_norm(x)
|
||||||
|
|
||||||
|
# -- incremental (streaming) --
|
||||||
|
|
||||||
|
def forward_conv_incremental(self, x_in, tail1, tail2):
|
||||||
|
"""Process new mel frames through the two causal convs using cached tails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_in: [1, N, mel_channels]
|
||||||
|
tail1: [1, pad1, mel_channels] or None (first call)
|
||||||
|
tail2: [1, pad2, dim] or None (first call)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(out, new_tail1, new_tail2)
|
||||||
|
"""
|
||||||
|
# Conv1 (kernel=3, stride=1 → left_pad=2)
|
||||||
|
if tail1 is not None:
|
||||||
|
c1_in = mx.concatenate([tail1, x_in], axis=1)
|
||||||
|
else:
|
||||||
|
c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
|
||||||
|
new_tail1 = x_in[:, -self.conv1.left_pad :, :]
|
||||||
|
c1_out = nn.gelu(
|
||||||
|
mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conv2 (kernel=3, stride=2 → left_pad=1)
|
||||||
|
if tail2 is not None:
|
||||||
|
c2_in = mx.concatenate([tail2, c1_out], axis=1)
|
||||||
|
else:
|
||||||
|
c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
|
||||||
|
new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
|
||||||
|
c2_out = nn.gelu(
|
||||||
|
mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
|
||||||
|
)
|
||||||
|
|
||||||
|
return c2_out, new_tail1, new_tail2
|
||||||
|
|
||||||
|
def forward_transformer_incremental(self, x, cache_list):
|
||||||
|
"""Run transformer blocks with per-layer KV caches."""
|
||||||
|
for i, blk in enumerate(self.blocks):
|
||||||
|
x = blk(x, mask="causal", cache=cache_list[i])
|
||||||
|
return self.final_norm(x)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Decoder components
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _DecoderAttention(nn.Module):
|
||||||
|
"""Grouped-query attention for the text decoder."""
|
||||||
|
|
||||||
|
def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||||
|
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||||
|
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
B, L, _ = x.shape
|
||||||
|
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
pos = cache.offset if cache is not None else 0
|
||||||
|
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
|
||||||
|
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
k, v = cache.update_and_fetch(k, v)
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||||
|
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
|
||||||
|
|
||||||
|
|
||||||
|
class _DecoderFFN(nn.Module):
|
||||||
|
"""SwiGLU feed-forward for decoder layers."""
|
||||||
|
|
||||||
|
def __init__(self, dim, hidden):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Linear(dim, hidden, bias=False)
|
||||||
|
self.up = nn.Linear(dim, hidden, bias=False)
|
||||||
|
self.down = nn.Linear(hidden, dim, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.down(nn.silu(self.gate(x)) * self.up(x))
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveScaling(nn.Module):
|
||||||
|
"""Small MLP that produces a multiplicative scale from the delay embedding,
|
||||||
|
used to condition the FFN on the streaming delay."""
|
||||||
|
|
||||||
|
def __init__(self, dim, bottleneck):
|
||||||
|
super().__init__()
|
||||||
|
self.proj_in = nn.Linear(dim, bottleneck, bias=False)
|
||||||
|
self.proj_out = nn.Linear(bottleneck, dim, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, cond):
|
||||||
|
return self.proj_out(nn.gelu(self.proj_in(cond)))
|
||||||
|
|
||||||
|
|
||||||
|
class _DecoderBlock(nn.Module):
|
||||||
|
def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
|
||||||
|
self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
|
||||||
|
self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
|
||||||
|
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
|
||||||
|
self.ffn = _DecoderFFN(dim, hidden)
|
||||||
|
|
||||||
|
def __call__(self, x, delay_cond, mask=None, cache=None):
|
||||||
|
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
|
||||||
|
scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond))
|
||||||
|
x = x + self.ffn(scaled)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TextDecoder(nn.Module):
|
||||||
|
"""Mistral-style causal language model with adaptive time-conditioning."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int = 3072,
|
||||||
|
n_layers: int = 26,
|
||||||
|
n_heads: int = 32,
|
||||||
|
n_kv_heads: int = 8,
|
||||||
|
head_dim: int = 128,
|
||||||
|
hidden_dim: int = 9216,
|
||||||
|
vocab_size: int = 131072,
|
||||||
|
rope_theta: float = 1e6,
|
||||||
|
cond_dim: int = 32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||||
|
self.blocks = [
|
||||||
|
_DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
|
||||||
|
for _ in range(n_layers)
|
||||||
|
]
|
||||||
|
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
|
||||||
|
|
||||||
|
def embed(self, token_ids: mx.array) -> mx.array:
|
||||||
|
return self.token_embedding(token_ids)
|
||||||
|
|
||||||
|
def __call__(self, x, delay_cond, mask=None, cache=None):
|
||||||
|
delay_cond = delay_cond.astype(x.dtype)
|
||||||
|
for i, blk in enumerate(self.blocks):
|
||||||
|
blk_cache = cache[i] if cache is not None else None
|
||||||
|
x = blk(x, delay_cond, mask, blk_cache)
|
||||||
|
x = self.final_norm(x)
|
||||||
|
return self.token_embedding.as_linear(x)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Adapter & embeddings
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderToDecoderAdapter(nn.Module):
|
||||||
|
"""Two-layer projection from encoder space to decoder space."""
|
||||||
|
|
||||||
|
def __init__(self, enc_dim: int, dec_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
|
||||||
|
self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.linear2(nn.gelu(self.linear1(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class DelayEmbedding(nn.Module):
|
||||||
|
"""Sinusoidal embedding that encodes the streaming delay as a conditioning
|
||||||
|
vector for the decoder's adaptive scaling."""
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 3072, theta: float = 10000.0):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
half = dim // 2
|
||||||
|
freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
|
||||||
|
self._freqs = freqs
|
||||||
|
|
||||||
|
def __call__(self, delay: mx.array) -> mx.array:
|
||||||
|
t = delay.reshape(-1, 1).astype(mx.float32)
|
||||||
|
angles = t * self._freqs
|
||||||
|
return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Top-level model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralMLXModel(nn.Module):
|
||||||
|
"""Top-level Voxtral Realtime model wiring encoder, adapter, and decoder."""
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
|
||||||
|
audio_cfg = enc_cfg["audio_encoding_args"]
|
||||||
|
ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]
|
||||||
|
|
||||||
|
self.encoder = StreamingEncoder(
|
||||||
|
mel_channels=audio_cfg["num_mel_bins"],
|
||||||
|
dim=enc_cfg["dim"],
|
||||||
|
n_layers=enc_cfg["n_layers"],
|
||||||
|
n_heads=enc_cfg["n_heads"],
|
||||||
|
head_dim=enc_cfg["head_dim"],
|
||||||
|
hidden_dim=enc_cfg["hidden_dim"],
|
||||||
|
rope_theta=enc_cfg["rope_theta"],
|
||||||
|
sliding_window=enc_cfg["sliding_window"],
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_input_dim = enc_cfg["dim"] * ds_factor
|
||||||
|
decoder_dim = config["dim"]
|
||||||
|
cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)
|
||||||
|
|
||||||
|
self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)
|
||||||
|
|
||||||
|
self.decoder = TextDecoder(
|
||||||
|
dim=decoder_dim,
|
||||||
|
n_layers=config["n_layers"],
|
||||||
|
n_heads=config["n_heads"],
|
||||||
|
n_kv_heads=config["n_kv_heads"],
|
||||||
|
head_dim=config["head_dim"],
|
||||||
|
hidden_dim=config["hidden_dim"],
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
rope_theta=config["rope_theta"],
|
||||||
|
cond_dim=cond_bottleneck,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.delay_embedding = DelayEmbedding(dim=decoder_dim)
|
||||||
|
self.ds_factor = ds_factor
|
||||||
|
|
||||||
|
# -- batch encode --
|
||||||
|
|
||||||
|
def encode(self, mel: mx.array) -> mx.array:
|
||||||
|
T = mel.shape[1]
|
||||||
|
if T % 2 != 0:
|
||||||
|
mel = mel[:, 1:]
|
||||||
|
|
||||||
|
h = self.encoder.forward(mel) # [1, T/2, enc_dim]
|
||||||
|
h = h[0]
|
||||||
|
|
||||||
|
n = h.shape[0]
|
||||||
|
trim = n % self.ds_factor
|
||||||
|
if trim:
|
||||||
|
h = h[trim:]
|
||||||
|
n = h.shape[0]
|
||||||
|
|
||||||
|
h = h.reshape(n // self.ds_factor, -1)
|
||||||
|
return self.adapter(h)
|
||||||
|
|
||||||
|
# -- incremental encode --
|
||||||
|
|
||||||
|
def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
|
||||||
|
"""Incrementally encode new mel frames.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
|
||||||
|
"""
|
||||||
|
x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)
|
||||||
|
|
||||||
|
x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)
|
||||||
|
|
||||||
|
if enc_cache is None:
|
||||||
|
enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]
|
||||||
|
|
||||||
|
x = self.encoder.forward_transformer_incremental(x, enc_cache)
|
||||||
|
x = x[0] # [N, enc_dim]
|
||||||
|
|
||||||
|
if ds_remainder is not None:
|
||||||
|
x = mx.concatenate([ds_remainder, x])
|
||||||
|
|
||||||
|
n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
|
||||||
|
if n_full == 0:
|
||||||
|
return None, conv_tail1, conv_tail2, enc_cache, x
|
||||||
|
|
||||||
|
leftover = x[n_full:] if x.shape[0] > n_full else None
|
||||||
|
x = x[:n_full].reshape(n_full // self.ds_factor, -1)
|
||||||
|
return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover
|
||||||
|
|
||||||
|
# -- decode --
|
||||||
|
|
||||||
|
def decode(self, embeddings, delay_cond, mask=None, cache=None):
|
||||||
|
return self.decoder(embeddings, delay_cond, mask, cache)
|
||||||
202
whisperlivekit/voxtral_mlx/spectrogram.py
Normal file
202
whisperlivekit/voxtral_mlx/spectrogram.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""
|
||||||
|
Mel spectrogram computation for Voxtral Realtime.
|
||||||
|
|
||||||
|
Provides both a full-audio function and an incremental streaming variant
|
||||||
|
that maintains overlap state between calls. The DFT is computed via
|
||||||
|
matrix multiplication in MLX — no external FFT dependency required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Audio / mel constants matching the Voxtral Realtime model expectations.
|
||||||
|
SAMPLE_RATE = 16_000
|
||||||
|
WINDOW_SIZE = 400 # n_fft
|
||||||
|
HOP = 160
|
||||||
|
MEL_BANDS = 128
|
||||||
|
MEL_MAX = 1.5 # global log-mel normalisation ceiling
|
||||||
|
# Each output audio token spans: hop * conv_stride(2) * downsample_factor(4)
|
||||||
|
SAMPLES_PER_TOKEN = HOP * 2 * 4 # = 1280 samples = 80 ms
|
||||||
|
|
||||||
|
# Padding tokens used by the model prompt structure.
|
||||||
|
LEFT_PAD_TOKENS = 32
|
||||||
|
RIGHT_PAD_TOKENS = 17
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Slaney mel filterbank
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_slaney_filterbank(
|
||||||
|
sr: int = SAMPLE_RATE,
|
||||||
|
n_fft: int = WINDOW_SIZE,
|
||||||
|
n_mels: int = MEL_BANDS,
|
||||||
|
lo_hz: float = 0.0,
|
||||||
|
hi_hz: float = 8000.0,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Compute a Slaney-normalised triangular mel filterbank.
|
||||||
|
|
||||||
|
Returns an array of shape ``[n_mels, n_fft//2 + 1]``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _hz2mel(f):
|
||||||
|
threshold = 1000.0
|
||||||
|
base_mel = 15.0
|
||||||
|
log_coeff = 27.0 / np.log(6.4)
|
||||||
|
mel = 3.0 * f / 200.0
|
||||||
|
if isinstance(f, np.ndarray):
|
||||||
|
above = f >= threshold
|
||||||
|
mel[above] = base_mel + np.log(f[above] / threshold) * log_coeff
|
||||||
|
elif f >= threshold:
|
||||||
|
mel = base_mel + np.log(f / threshold) * log_coeff
|
||||||
|
return mel
|
||||||
|
|
||||||
|
def _mel2hz(m):
|
||||||
|
threshold = 1000.0
|
||||||
|
base_mel = 15.0
|
||||||
|
log_coeff = np.log(6.4) / 27.0
|
||||||
|
hz = 200.0 * m / 3.0
|
||||||
|
above = m >= base_mel
|
||||||
|
hz[above] = threshold * np.exp(log_coeff * (m[above] - base_mel))
|
||||||
|
return hz
|
||||||
|
|
||||||
|
n_bins = n_fft // 2 + 1
|
||||||
|
fft_hz = np.linspace(0, sr / 2, n_bins)
|
||||||
|
mel_lo, mel_hi = _hz2mel(lo_hz), _hz2mel(hi_hz)
|
||||||
|
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
|
||||||
|
hz_pts = _mel2hz(mel_pts)
|
||||||
|
diffs = np.diff(hz_pts)
|
||||||
|
|
||||||
|
slopes = np.expand_dims(hz_pts, 0) - np.expand_dims(fft_hz, 1)
|
||||||
|
rising = -slopes[:, :-2] / diffs[:-1]
|
||||||
|
falling = slopes[:, 2:] / diffs[1:]
|
||||||
|
fb = np.maximum(0.0, np.minimum(rising, falling))
|
||||||
|
|
||||||
|
# Slaney area normalisation
|
||||||
|
widths = 2.0 / (hz_pts[2 : n_mels + 2] - hz_pts[:n_mels])
|
||||||
|
fb *= np.expand_dims(widths, 0)
|
||||||
|
return fb.T.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
_CACHED_FILTERS: mx.array | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _mel_filters() -> mx.array:
|
||||||
|
global _CACHED_FILTERS
|
||||||
|
if _CACHED_FILTERS is None:
|
||||||
|
_CACHED_FILTERS = mx.array(_build_slaney_filterbank())
|
||||||
|
return _CACHED_FILTERS
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DFT helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _hann_window() -> mx.array:
|
||||||
|
return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def _dft_matrices():
|
||||||
|
"""Pre-compute the real / imaginary DFT basis matrices."""
|
||||||
|
n_bins = WINDOW_SIZE // 2 + 1
|
||||||
|
k = mx.arange(n_bins, dtype=mx.float32)[:, None]
|
||||||
|
n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
|
||||||
|
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
|
||||||
|
return mx.cos(phase), mx.sin(phase)
|
||||||
|
|
||||||
|
|
||||||
|
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
|
||||||
|
"""Frame *audio* using the Hann window and compute power spectrogram."""
|
||||||
|
n_bins = WINDOW_SIZE // 2 + 1
|
||||||
|
n_frames = 1 + (audio.shape[0] - WINDOW_SIZE) // HOP
|
||||||
|
if n_frames <= 0:
|
||||||
|
return mx.zeros((0, n_bins))
|
||||||
|
|
||||||
|
offsets = (mx.arange(n_frames) * HOP)[:, None]
|
||||||
|
indices = offsets + mx.arange(WINDOW_SIZE)[None, :]
|
||||||
|
windowed = audio[indices] * window[None, :]
|
||||||
|
|
||||||
|
dft_re, dft_im = _dft_matrices()
|
||||||
|
real_part = windowed @ dft_re.T
|
||||||
|
imag_part = windowed @ dft_im.T
|
||||||
|
return real_part ** 2 + imag_part ** 2
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_mel_and_log(power: mx.array) -> mx.array:
|
||||||
|
"""Convert a power spectrogram to log-mel and normalise."""
|
||||||
|
mel = power @ _mel_filters().T
|
||||||
|
log_mel = mx.log10(mx.maximum(mel, 1e-10))
|
||||||
|
log_mel = mx.maximum(log_mel, MEL_MAX - 8.0)
|
||||||
|
return (log_mel + 4.0) / 4.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def compute_mel(audio: np.ndarray) -> mx.array:
|
||||||
|
"""Compute log-mel spectrogram for a complete audio signal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: 1-D float32 numpy array at ``SAMPLE_RATE``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``[MEL_BANDS, T]`` MLX array.
|
||||||
|
"""
|
||||||
|
x = mx.array(audio)
|
||||||
|
pad = WINDOW_SIZE // 2
|
||||||
|
x = mx.pad(x, [(pad, pad)])
|
||||||
|
window = _hann_window()
|
||||||
|
|
||||||
|
power = _stft_frames(x, window)
|
||||||
|
# Drop last frame to match reference STFT behaviour
|
||||||
|
power = power[:-1]
|
||||||
|
return _apply_mel_and_log(power).T
|
||||||
|
|
||||||
|
|
||||||
|
def compute_mel_streaming(
|
||||||
|
chunk: np.ndarray,
|
||||||
|
overlap: np.ndarray | None,
|
||||||
|
) -> tuple[mx.array, np.ndarray]:
|
||||||
|
"""Incrementally compute log-mel for a new audio chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk: New audio samples (float32 numpy).
|
||||||
|
overlap: The last ``WINDOW_SIZE - HOP`` = 240 samples from the
|
||||||
|
previous call, or *None* on the first call (uses zero-padding).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``(mel, new_overlap)`` where *mel* is ``[MEL_BANDS, N]`` and
|
||||||
|
*new_overlap* is the 240-sample tail for the next call.
|
||||||
|
"""
|
||||||
|
tail_len = WINDOW_SIZE - HOP # 240
|
||||||
|
|
||||||
|
if overlap is not None:
|
||||||
|
combined = np.concatenate([overlap, chunk])
|
||||||
|
else:
|
||||||
|
combined = np.concatenate([np.zeros(WINDOW_SIZE // 2, dtype=np.float32), chunk])
|
||||||
|
|
||||||
|
new_overlap = combined[-tail_len:].copy()
|
||||||
|
|
||||||
|
x = mx.array(combined)
|
||||||
|
window = _hann_window()
|
||||||
|
power = _stft_frames(x, window)
|
||||||
|
|
||||||
|
if power.shape[0] == 0:
|
||||||
|
return mx.zeros((MEL_BANDS, 0)), new_overlap
|
||||||
|
|
||||||
|
return _apply_mel_and_log(power).T, new_overlap
|
||||||
|
|
||||||
|
|
||||||
|
def pad_audio(
|
||||||
|
audio: np.ndarray,
|
||||||
|
n_left: int = LEFT_PAD_TOKENS,
|
||||||
|
n_right: int = RIGHT_PAD_TOKENS,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Pad audio with silence for batch (non-streaming) inference."""
|
||||||
|
left = n_left * SAMPLES_PER_TOKEN
|
||||||
|
align = (SAMPLES_PER_TOKEN - (len(audio) % SAMPLES_PER_TOKEN)) % SAMPLES_PER_TOKEN
|
||||||
|
right = align + n_right * SAMPLES_PER_TOKEN
|
||||||
|
return np.pad(audio, (left, right))
|
||||||
521
whisperlivekit/voxtral_mlx_asr.py
Normal file
521
whisperlivekit/voxtral_mlx_asr.py
Normal file
@@ -0,0 +1,521 @@
|
|||||||
|
"""
|
||||||
|
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
|
||||||
|
|
||||||
|
Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
|
||||||
|
(streaming processor) that plug into WhisperLiveKit's audio processing
|
||||||
|
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
|
||||||
|
|
||||||
|
Unlike the HuggingFace backend, this runs the full inference loop in-process
|
||||||
|
(no background thread / queue) — MLX operations on Apple Silicon are fast
|
||||||
|
enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
|
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
|
||||||
|
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
|
||||||
|
from whisperlivekit.voxtral_mlx.spectrogram import (
|
||||||
|
SAMPLES_PER_TOKEN,
|
||||||
|
LEFT_PAD_TOKENS,
|
||||||
|
RIGHT_PAD_TOKENS,
|
||||||
|
compute_mel_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Decoder sliding-window size (matches the model's training configuration).
|
||||||
|
_DECODER_WINDOW = 8192
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
|
||||||
|
"""Build the prompt token sequence and return ``(token_ids, n_delay)``."""
|
||||||
|
pad_id = tokenizer.get_special_token("[STREAMING_PAD]")
|
||||||
|
ids = [tokenizer.bos_id] + [pad_id] * (n_left_pad + n_delay)
|
||||||
|
return ids, n_delay
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model holder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralMLXASR:
|
||||||
|
"""Lightweight model holder — loads the MLX Voxtral model once and keeps
|
||||||
|
it alive for the lifetime of the server."""
|
||||||
|
|
||||||
|
sep = " "
|
||||||
|
SAMPLING_RATE = 16_000
|
||||||
|
|
||||||
|
def __init__(self, logfile=sys.stderr, **kwargs):
|
||||||
|
self.logfile = logfile
|
||||||
|
self.transcribe_kargs = {}
|
||||||
|
|
||||||
|
lan = kwargs.get("lan", "auto")
|
||||||
|
self.original_language = None if lan == "auto" else lan
|
||||||
|
|
||||||
|
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
|
||||||
|
if not model_path:
|
||||||
|
model_size = kwargs.get("model_size", "")
|
||||||
|
if model_size and ("/" in model_size or model_size.startswith(".")):
|
||||||
|
model_path = model_size
|
||||||
|
else:
|
||||||
|
model_path = DEFAULT_MODEL_ID
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
logger.info("Loading Voxtral MLX model '%s' ...", model_path)
|
||||||
|
self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
|
||||||
|
logger.info("Voxtral MLX model loaded in %.2fs", time.time() - t0)
|
||||||
|
|
||||||
|
self.backend_choice = "voxtral-mlx"
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
pass # all work happens in the online processor
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Online processor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralMLXOnlineProcessor:
|
||||||
|
"""Streaming processor that incrementally encodes audio and decodes text
|
||||||
|
using the MLX Voxtral model.
|
||||||
|
|
||||||
|
Lifecycle (called by ``AudioProcessor.transcription_processor``):
|
||||||
|
|
||||||
|
insert_audio_chunk(pcm, time) → process_iter() → get_buffer()
|
||||||
|
... repeat ...
|
||||||
|
start_silence() / end_silence()
|
||||||
|
finish()
|
||||||
|
"""
|
||||||
|
|
||||||
|
SAMPLING_RATE = 16_000
|
||||||
|
|
||||||
|
def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
|
||||||
|
self.asr = asr
|
||||||
|
self.logfile = logfile
|
||||||
|
self.end = 0.0
|
||||||
|
self.buffer: list = []
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
self._model = asr.model
|
||||||
|
self._tokenizer = asr.tokenizer
|
||||||
|
|
||||||
|
# Pre-compute prompt tokens and delay conditioning (constant across utterances).
|
||||||
|
self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
|
||||||
|
self._prefix_len = len(self._prompt_ids)
|
||||||
|
|
||||||
|
self._delay_cond = self._model.delay_embedding(
|
||||||
|
mx.array([self._n_delay], dtype=mx.float32)
|
||||||
|
)
|
||||||
|
mx.eval(self._delay_cond)
|
||||||
|
|
||||||
|
self._prompt_embeds = self._model.decoder.embed(
|
||||||
|
mx.array([self._prompt_ids])
|
||||||
|
)[0] # [prefix_len, dim]
|
||||||
|
mx.eval(self._prompt_embeds)
|
||||||
|
|
||||||
|
self._eos_id = self._tokenizer.eos_id
|
||||||
|
self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
|
||||||
|
# The streaming model has an inherent delay: text for audio at position P
|
||||||
|
# is generated at decoder position P + n_delay. Compensate timestamps.
|
||||||
|
self._delay_secs = self._n_delay * self._secs_per_token
|
||||||
|
|
||||||
|
self._reset_state()
|
||||||
|
|
||||||
|
# -- state management --
|
||||||
|
|
||||||
|
def _reset_state(self):
|
||||||
|
"""Reset all incremental state for a fresh utterance."""
|
||||||
|
# Audio accumulation
|
||||||
|
self._pending = np.zeros(0, dtype=np.float32)
|
||||||
|
# Mel overlap
|
||||||
|
self._mel_overlap: np.ndarray | None = None
|
||||||
|
# Encoder incremental state
|
||||||
|
self._conv_tail1 = None
|
||||||
|
self._conv_tail2 = None
|
||||||
|
self._enc_cache = None
|
||||||
|
self._ds_remainder = None
|
||||||
|
# Audio embeddings not yet decoded
|
||||||
|
self._audio_embeds: mx.array | None = None
|
||||||
|
# Decoder state
|
||||||
|
self._dec_cache: list[SlidingKVCache] | None = None
|
||||||
|
self._last_token: mx.array | None = None
|
||||||
|
# Bookkeeping
|
||||||
|
self._samples_encoded = 0
|
||||||
|
self._positions_decoded = 0
|
||||||
|
self._prefilled = False
|
||||||
|
self._first_chunk = True
|
||||||
|
# Text state
|
||||||
|
self._full_text = ""
|
||||||
|
self._n_text_tokens = 0
|
||||||
|
self._n_committed_words = 0
|
||||||
|
self._time_offset = 0.0
|
||||||
|
# Per-word audio position tracking: decoder position (relative to prefix)
|
||||||
|
# where each word in _full_text started and ended
|
||||||
|
self._word_audio_starts: list[int] = [] # audio pos where word i started
|
||||||
|
self._word_audio_ends: list[int] = [] # audio pos where word i last produced a token
|
||||||
|
self._current_word_pos: Optional[int] = None # audio pos of current (incomplete) word's first token
|
||||||
|
|
||||||
|
# -- audio ingestion --
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||||
|
self.end = audio_stream_end_time
|
||||||
|
self._pending = np.append(self._pending, audio)
|
||||||
|
self.audio_buffer = self._pending
|
||||||
|
|
||||||
|
# -- core processing --
|
||||||
|
|
||||||
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
|
try:
|
||||||
|
return self._step(is_last)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
|
||||||
|
# 1. Encode any new audio
|
||||||
|
self._encode_pending()
|
||||||
|
|
||||||
|
if self._audio_embeds is None:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
# 2. Compute how many positions we can safely decode
|
||||||
|
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
|
||||||
|
n_available = self._audio_embeds.shape[0]
|
||||||
|
n_decodable = min(n_available, total_safe - self._positions_decoded)
|
||||||
|
|
||||||
|
if n_decodable <= 0:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
# 3. Prefill if needed
|
||||||
|
if not self._prefilled:
|
||||||
|
if self._positions_decoded + n_available < self._prefix_len:
|
||||||
|
return [], self.end
|
||||||
|
self._do_prefill()
|
||||||
|
# Re-check after consuming prefix embeddings
|
||||||
|
n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
|
||||||
|
n_decodable = min(n_available, total_safe - self._positions_decoded)
|
||||||
|
|
||||||
|
if n_decodable <= 0 or self._audio_embeds is None:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
# 4. Decode available positions
|
||||||
|
hit_eos = self._decode_positions(n_decodable)
|
||||||
|
|
||||||
|
if hit_eos:
|
||||||
|
# Flush words, reset for next utterance
|
||||||
|
words = self._flush_all_words()
|
||||||
|
logger.debug(
|
||||||
|
"[voxtral-mlx] EOS hit during stream: flushed %d words, "
|
||||||
|
"samples_encoded=%d (%.2fs), text='%s'",
|
||||||
|
len(words), self._samples_encoded,
|
||||||
|
self._samples_encoded / self.SAMPLING_RATE,
|
||||||
|
self._full_text[-60:] if self._full_text else "",
|
||||||
|
)
|
||||||
|
saved_offset = self._time_offset
|
||||||
|
self._reset_state()
|
||||||
|
self._time_offset = saved_offset
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
# 5. Extract committed words (all but the last, which may still grow)
|
||||||
|
return self._extract_committed_words(), self.end
|
||||||
|
|
||||||
|
def _encode_pending(self):
|
||||||
|
"""Feed pending audio through the incremental encoder."""
|
||||||
|
available = len(self._pending)
|
||||||
|
if available < SAMPLES_PER_TOKEN:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._first_chunk:
|
||||||
|
# First chunk: prepend silence for left-padding
|
||||||
|
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
|
||||||
|
left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
|
||||||
|
chunk = np.concatenate([left_pad, self._pending[:n_take]])
|
||||||
|
self._pending = self._pending[n_take:]
|
||||||
|
self._samples_encoded += n_take
|
||||||
|
self._first_chunk = False
|
||||||
|
else:
|
||||||
|
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
|
||||||
|
chunk = self._pending[:n_take]
|
||||||
|
self._pending = self._pending[n_take:]
|
||||||
|
self._samples_encoded += n_take
|
||||||
|
|
||||||
|
mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
|
||||||
|
|
||||||
|
embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
|
||||||
|
self._model.encode_incremental(
|
||||||
|
mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if embeds is not None:
|
||||||
|
mx.eval(embeds)
|
||||||
|
if self._audio_embeds is not None:
|
||||||
|
self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
|
||||||
|
else:
|
||||||
|
self._audio_embeds = embeds
|
||||||
|
|
||||||
|
self.audio_buffer = self._pending
|
||||||
|
|
||||||
|
def _do_prefill(self):
|
||||||
|
"""Run the decoder prefill pass over the prompt + first audio embeddings."""
|
||||||
|
n_dec_layers = len(self._model.decoder.blocks)
|
||||||
|
self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]
|
||||||
|
|
||||||
|
prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
|
||||||
|
prefix_embeds = prefix_embeds[None, :, :] # [1, prefix_len, dim]
|
||||||
|
|
||||||
|
logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
|
||||||
|
mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])
|
||||||
|
|
||||||
|
self._last_token = self._sample(logits)
|
||||||
|
mx.async_eval(self._last_token)
|
||||||
|
|
||||||
|
# Remove consumed prefix embeddings
|
||||||
|
self._audio_embeds = self._audio_embeds[self._prefix_len :]
|
||||||
|
if self._audio_embeds.shape[0] == 0:
|
||||||
|
self._audio_embeds = None
|
||||||
|
self._positions_decoded = self._prefix_len
|
||||||
|
self._prefilled = True
|
||||||
|
|
||||||
|
def _decode_positions(self, n: int) -> bool:
|
||||||
|
"""Autoregressively decode *n* positions. Returns True on EOS."""
|
||||||
|
base_pos = self._positions_decoded # absolute position before this batch
|
||||||
|
for i in range(n):
|
||||||
|
tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
|
||||||
|
combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
|
||||||
|
logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
|
||||||
|
next_tok = self._sample(logits)
|
||||||
|
mx.async_eval(next_tok)
|
||||||
|
|
||||||
|
token_id = self._last_token.item()
|
||||||
|
if token_id == self._eos_id:
|
||||||
|
# Close the current word if one is being built
|
||||||
|
if self._current_word_pos is not None:
|
||||||
|
self._word_audio_ends.append(base_pos + i - self._prefix_len)
|
||||||
|
self._current_word_pos = None
|
||||||
|
self._trim_embeds(i)
|
||||||
|
self._positions_decoded += i
|
||||||
|
return True
|
||||||
|
|
||||||
|
text = self._tokenizer.decode(
|
||||||
|
[token_id], special_token_policy=SpecialTokenPolicy.IGNORE
|
||||||
|
)
|
||||||
|
|
||||||
|
if text:
|
||||||
|
audio_pos = base_pos + i - self._prefix_len
|
||||||
|
|
||||||
|
# Detect word boundary: new word starts with space or is the very first text
|
||||||
|
if text.lstrip() != text or not self._full_text:
|
||||||
|
# Close previous word if exists
|
||||||
|
if self._current_word_pos is not None:
|
||||||
|
self._word_audio_ends.append(audio_pos)
|
||||||
|
# Start new word
|
||||||
|
self._word_audio_starts.append(audio_pos)
|
||||||
|
self._current_word_pos = audio_pos
|
||||||
|
elif self._current_word_pos is None:
|
||||||
|
# First token of first word (no leading space)
|
||||||
|
self._word_audio_starts.append(audio_pos)
|
||||||
|
self._current_word_pos = audio_pos
|
||||||
|
|
||||||
|
self._full_text += text
|
||||||
|
self._n_text_tokens += 1
|
||||||
|
|
||||||
|
if i > 0 and i % 256 == 0:
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
self._last_token = next_tok
|
||||||
|
|
||||||
|
self._positions_decoded += n
|
||||||
|
self._trim_embeds(n)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _trim_embeds(self, n_consumed: int):
|
||||||
|
if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
|
||||||
|
self._audio_embeds = self._audio_embeds[n_consumed:]
|
||||||
|
else:
|
||||||
|
self._audio_embeds = None
|
||||||
|
|
||||||
|
def _sample(self, logits: mx.array) -> mx.array:
|
||||||
|
return mx.argmax(logits[0, -1:], axis=-1).squeeze()
|
||||||
|
|
||||||
|
# -- word extraction --
|
||||||
|
|
||||||
|
def _audio_pos_to_time(self, pos: int) -> float:
|
||||||
|
"""Convert an audio position (relative to prefix end) to seconds."""
|
||||||
|
return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)
|
||||||
|
|
||||||
|
def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
|
||||||
|
"""Compute (start, end) time for a word using tracked word positions."""
|
||||||
|
starts = self._word_audio_starts
|
||||||
|
ends = self._word_audio_ends
|
||||||
|
|
||||||
|
if not starts:
|
||||||
|
return self._time_offset, self._time_offset
|
||||||
|
|
||||||
|
# Get start position for this word
|
||||||
|
if word_idx < len(starts):
|
||||||
|
t0 = self._audio_pos_to_time(starts[word_idx])
|
||||||
|
else:
|
||||||
|
# Fallback: estimate from last known position
|
||||||
|
last_pos = ends[-1] if ends else starts[-1]
|
||||||
|
t0 = self._audio_pos_to_time(last_pos + 1)
|
||||||
|
|
||||||
|
# Get end position: use the start of the next word, or the end of this word
|
||||||
|
if word_idx + 1 < len(starts):
|
||||||
|
t1 = self._audio_pos_to_time(starts[word_idx + 1])
|
||||||
|
elif word_idx < len(ends):
|
||||||
|
t1 = self._audio_pos_to_time(ends[word_idx] + 1)
|
||||||
|
else:
|
||||||
|
# Last word, still being built: use last known position + 1 token
|
||||||
|
last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
|
||||||
|
t1 = self._audio_pos_to_time(last_pos + 1)
|
||||||
|
|
||||||
|
return t0, t1
|
||||||
|
|
||||||
|
def _extract_committed_words(self) -> List[ASRToken]:
|
||||||
|
"""Return complete words (all except the last which may still grow)."""
|
||||||
|
if not self._full_text:
|
||||||
|
return []
|
||||||
|
words = self._full_text.split()
|
||||||
|
tokens: List[ASRToken] = []
|
||||||
|
n_total = max(len(words), 1)
|
||||||
|
|
||||||
|
while len(words) > self._n_committed_words + 1:
|
||||||
|
w = words[self._n_committed_words]
|
||||||
|
idx = self._n_committed_words
|
||||||
|
t0, t1 = self._word_time_range(idx, n_total)
|
||||||
|
label = w if idx == 0 else " " + w
|
||||||
|
tokens.append(ASRToken(start=t0, end=t1, text=label))
|
||||||
|
self._n_committed_words += 1
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def _flush_all_words(self) -> List[ASRToken]:
|
||||||
|
"""Flush every word including the last partial one."""
|
||||||
|
if not self._full_text:
|
||||||
|
return []
|
||||||
|
words = self._full_text.split()
|
||||||
|
tokens: List[ASRToken] = []
|
||||||
|
n_total = max(len(words), 1)
|
||||||
|
|
||||||
|
while self._n_committed_words < len(words):
|
||||||
|
w = words[self._n_committed_words]
|
||||||
|
idx = self._n_committed_words
|
||||||
|
t0, t1 = self._word_time_range(idx, n_total)
|
||||||
|
label = w if idx == 0 else " " + w
|
||||||
|
tokens.append(ASRToken(start=t0, end=t1, text=label))
|
||||||
|
self._n_committed_words += 1
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# -- interface methods --
|
||||||
|
|
||||||
|
def get_buffer(self) -> Transcript:
|
||||||
|
if not self._full_text:
|
||||||
|
return Transcript(start=None, end=None, text="")
|
||||||
|
words = self._full_text.split()
|
||||||
|
remaining = words[self._n_committed_words :]
|
||||||
|
if remaining:
|
||||||
|
return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
|
||||||
|
return Transcript(start=None, end=None, text="")
|
||||||
|
|
||||||
|
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
words = self._flush_all_words()
|
||||||
|
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
|
||||||
|
return words, self.end
|
||||||
|
|
||||||
|
def end_silence(self, silence_duration: float, offset: float):
|
||||||
|
self._time_offset += silence_duration
|
||||||
|
self.end += silence_duration
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker):
|
||||||
|
self.start_silence()
|
||||||
|
|
||||||
|
def warmup(self, audio, init_prompt=""):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||||
|
logger.debug(
|
||||||
|
"[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
|
||||||
|
"samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
|
||||||
|
len(self._pending),
|
||||||
|
self._audio_embeds.shape if self._audio_embeds is not None else None,
|
||||||
|
self._samples_encoded,
|
||||||
|
self._positions_decoded,
|
||||||
|
self._prefilled,
|
||||||
|
self._full_text[-80:] if self._full_text else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
|
||||||
|
remainder = len(self._pending) % SAMPLES_PER_TOKEN
|
||||||
|
if remainder > 0:
|
||||||
|
align_pad = SAMPLES_PER_TOKEN - remainder
|
||||||
|
else:
|
||||||
|
align_pad = 0
|
||||||
|
|
||||||
|
# Add alignment + right-padding silence
|
||||||
|
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
|
||||||
|
if total_pad > 0:
|
||||||
|
self._pending = np.append(
|
||||||
|
self._pending, np.zeros(total_pad, dtype=np.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encode remaining audio (including right-padding)
|
||||||
|
self._encode_pending()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
|
||||||
|
self._audio_embeds.shape if self._audio_embeds is not None else None,
|
||||||
|
len(self._pending),
|
||||||
|
)
|
||||||
|
|
||||||
|
hit_eos = False
|
||||||
|
|
||||||
|
# Decode everything that's left from right-padding
|
||||||
|
if self._audio_embeds is not None and self._prefilled:
|
||||||
|
hit_eos = self._decode_positions(self._audio_embeds.shape[0])
|
||||||
|
logger.debug(
|
||||||
|
"[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
|
||||||
|
hit_eos, self._full_text[-80:] if self._full_text else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flush last token if it wasn't EOS
|
||||||
|
if self._last_token is not None:
|
||||||
|
tid = self._last_token.item()
|
||||||
|
if tid != self._eos_id:
|
||||||
|
text = self._tokenizer.decode(
|
||||||
|
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
|
||||||
|
)
|
||||||
|
if text:
|
||||||
|
last_pos = self._positions_decoded - self._prefix_len
|
||||||
|
# Check if this starts a new word
|
||||||
|
if text.lstrip() != text or not self._full_text:
|
||||||
|
if self._current_word_pos is not None:
|
||||||
|
self._word_audio_ends.append(last_pos)
|
||||||
|
self._word_audio_starts.append(last_pos)
|
||||||
|
self._current_word_pos = last_pos
|
||||||
|
elif self._current_word_pos is None:
|
||||||
|
self._word_audio_starts.append(last_pos)
|
||||||
|
self._current_word_pos = last_pos
|
||||||
|
self._full_text += text
|
||||||
|
self._n_text_tokens += 1
|
||||||
|
|
||||||
|
# Close the last word if still open
|
||||||
|
if self._current_word_pos is not None:
|
||||||
|
last_pos = self._positions_decoded - self._prefix_len
|
||||||
|
self._word_audio_ends.append(last_pos)
|
||||||
|
self._current_word_pos = None
|
||||||
|
|
||||||
|
words = self._flush_all_words()
|
||||||
|
logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
|
||||||
|
return words, self.end
|
||||||
@@ -108,7 +108,7 @@ def available_models() -> List[str]:
|
|||||||
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
||||||
"""
|
"""
|
||||||
attempt to infer ModelDimensions from a HF style config.json located
|
attempt to infer ModelDimensions from a HF style config.json located
|
||||||
next to the given checkpoint, usefull for distilled models
|
next to the given checkpoint, usefull for distilled models/MLX models.
|
||||||
"""
|
"""
|
||||||
candidates = []
|
candidates = []
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
@@ -122,6 +122,25 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
|||||||
with open(candidate, "r", encoding="utf-8") as f:
|
with open(candidate, "r", encoding="utf-8") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
|
# native Whisper format
|
||||||
|
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
|
||||||
|
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
|
||||||
|
"n_text_head", "n_text_layer"]
|
||||||
|
if all(k in config for k in native_keys):
|
||||||
|
return ModelDimensions(
|
||||||
|
n_mels=config["n_mels"],
|
||||||
|
n_audio_ctx=config["n_audio_ctx"],
|
||||||
|
n_audio_state=config["n_audio_state"],
|
||||||
|
n_audio_head=config["n_audio_head"],
|
||||||
|
n_audio_layer=config["n_audio_layer"],
|
||||||
|
n_vocab=config["n_vocab"],
|
||||||
|
n_text_ctx=config["n_text_ctx"],
|
||||||
|
n_text_state=config["n_text_state"],
|
||||||
|
n_text_head=config["n_text_head"],
|
||||||
|
n_text_layer=config["n_text_layer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# HuggingFace format
|
||||||
try:
|
try:
|
||||||
return ModelDimensions(
|
return ModelDimensions(
|
||||||
n_mels=config["num_mel_bins"],
|
n_mels=config["num_mel_bins"],
|
||||||
@@ -236,6 +255,24 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
|
|||||||
return converted if converted else state_dict
|
return converted if converted else state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Converts an mlx whisper checkpoint to a default openai whisper one
|
||||||
|
"""
|
||||||
|
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
converted = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key == "alignment_heads":
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
|
||||||
|
converted[new_key] = value
|
||||||
|
|
||||||
|
return converted
|
||||||
|
|
||||||
|
|
||||||
def _load_lora_state(lora_path: str):
|
def _load_lora_state(lora_path: str):
|
||||||
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
||||||
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
||||||
@@ -264,10 +301,50 @@ def _collapse_hf_module_name(module: str):
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Resolve LoRA adapter path - handles both local paths and HuggingFace repo IDs.
|
||||||
|
|
||||||
|
If lora_path is a local directory containing adapter files, returns it as-is.
|
||||||
|
If lora_path looks like a HuggingFace repo ID (contains '/'), downloads and caches it.
|
||||||
|
"""
|
||||||
|
if not lora_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if it's already a valid local path
|
||||||
|
if os.path.isdir(lora_path):
|
||||||
|
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||||
|
if os.path.isfile(config_path):
|
||||||
|
return lora_path
|
||||||
|
|
||||||
|
# Try to download from HuggingFace Hub
|
||||||
|
if "/" in lora_path:
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
local_path = snapshot_download(
|
||||||
|
repo_id=lora_path,
|
||||||
|
allow_patterns=["adapter_config.json", "adapter_model.*"],
|
||||||
|
)
|
||||||
|
return local_path
|
||||||
|
except Exception as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
|
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
|
||||||
if not lora_path:
|
if not lora_path:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Resolve path (handles HuggingFace Hub download)
|
||||||
|
lora_path = _resolve_lora_path(lora_path)
|
||||||
|
if not lora_path:
|
||||||
|
return
|
||||||
|
|
||||||
config_path = os.path.join(lora_path, "adapter_config.json")
|
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||||
if not os.path.isfile(config_path):
|
if not os.path.isfile(config_path):
|
||||||
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
|
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
|
||||||
@@ -319,6 +396,75 @@ def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str])
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_checkpoint(
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
device: str,
|
||||||
|
in_memory: bool = False,
|
||||||
|
checkpoint_bytes: Optional[bytes] = None,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Load a checkpoint from a single file.
|
||||||
|
|
||||||
|
Handles .pt, .bin, and .safetensors formats.
|
||||||
|
"""
|
||||||
|
if checkpoint_bytes is not None:
|
||||||
|
with io.BytesIO(checkpoint_bytes) as fp:
|
||||||
|
return torch.load(fp, map_location=device)
|
||||||
|
|
||||||
|
file_path = Path(file_path)
|
||||||
|
suffix = file_path.suffix.lower()
|
||||||
|
|
||||||
|
if suffix == '.safetensors':
|
||||||
|
try:
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install safetensors to load .safetensors model files: `pip install safetensors`"
|
||||||
|
)
|
||||||
|
return load_file(str(file_path), device=device)
|
||||||
|
else:
|
||||||
|
if in_memory:
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
checkpoint_bytes = f.read()
|
||||||
|
with io.BytesIO(checkpoint_bytes) as fp:
|
||||||
|
return torch.load(fp, map_location=device)
|
||||||
|
else:
|
||||||
|
with open(file_path, "rb") as fp:
|
||||||
|
return torch.load(fp, map_location=device)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_sharded_checkpoint(
|
||||||
|
shard_files: List[Path],
|
||||||
|
device: str,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Load a sharded checkpoint (multiple .safetensors or .bin files).
|
||||||
|
|
||||||
|
Merges all shards into a single state dict.
|
||||||
|
"""
|
||||||
|
merged_state_dict = {}
|
||||||
|
first_suffix = shard_files[0].suffix.lower()
|
||||||
|
|
||||||
|
if first_suffix == '.safetensors':
|
||||||
|
try:
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install safetensors to load sharded .safetensors model: `pip install safetensors`"
|
||||||
|
)
|
||||||
|
for shard_path in shard_files:
|
||||||
|
shard_dict = load_file(str(shard_path), device=device)
|
||||||
|
merged_state_dict.update(shard_dict)
|
||||||
|
else:
|
||||||
|
for shard_path in shard_files:
|
||||||
|
with open(shard_path, "rb") as fp:
|
||||||
|
shard_dict = torch.load(fp, map_location=device)
|
||||||
|
if isinstance(shard_dict, dict):
|
||||||
|
merged_state_dict.update(shard_dict)
|
||||||
|
|
||||||
|
return merged_state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
name: str,
|
name: str,
|
||||||
device: Optional[Union[str, torch.device]] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
@@ -336,6 +482,8 @@ def load_model(
|
|||||||
name : str
|
name : str
|
||||||
one of the official model names listed by `whisper.available_models()`, or
|
one of the official model names listed by `whisper.available_models()`, or
|
||||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||||
|
Can be a single file (.pt, .bin, .safetensors), a directory containing model files,
|
||||||
|
or a sharded model directory with files like model-00001-of-00002.safetensors.
|
||||||
device : Union[str, torch.device]
|
device : Union[str, torch.device]
|
||||||
the PyTorch device to put the model into
|
the PyTorch device to put the model into
|
||||||
download_root: str
|
download_root: str
|
||||||
@@ -350,16 +498,51 @@ def load_model(
|
|||||||
model : Whisper
|
model : Whisper
|
||||||
The Whisper ASR model instance
|
The Whisper ASR model instance
|
||||||
"""
|
"""
|
||||||
|
from whisperlivekit.model_paths import detect_model_format
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
if download_root is None:
|
if download_root is None:
|
||||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||||
|
|
||||||
|
checkpoint = None
|
||||||
|
model_path_for_config = name # Used to find config.json for dims inference
|
||||||
|
|
||||||
if name in _MODELS:
|
if name in _MODELS:
|
||||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||||
|
if in_memory:
|
||||||
|
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_file)
|
||||||
|
else:
|
||||||
|
checkpoint = _load_checkpoint(checkpoint_file, device)
|
||||||
elif os.path.isfile(name):
|
elif os.path.isfile(name):
|
||||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
if in_memory:
|
||||||
|
with open(name, "rb") as f:
|
||||||
|
checkpoint_bytes = f.read()
|
||||||
|
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
|
||||||
|
else:
|
||||||
|
checkpoint = _load_checkpoint(name, device)
|
||||||
|
model_path_for_config = name
|
||||||
|
elif os.path.isdir(name):
|
||||||
|
model_info = detect_model_format(name)
|
||||||
|
|
||||||
|
if not model_info.has_pytorch:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No PyTorch checkpoint found in directory {name}. "
|
||||||
|
f"Expected .pt, .bin, or .safetensors file(s)."
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_info.is_sharded:
|
||||||
|
checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
|
||||||
|
else:
|
||||||
|
single_file = model_info.pytorch_files[0]
|
||||||
|
if in_memory:
|
||||||
|
with open(single_file, "rb") as f:
|
||||||
|
checkpoint_bytes = f.read()
|
||||||
|
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
|
||||||
|
else:
|
||||||
|
checkpoint = _load_checkpoint(single_file, device)
|
||||||
|
model_path_for_config = name
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Model {name} not found; available models = {available_models()}"
|
f"Model {name} not found; available models = {available_models()}"
|
||||||
@@ -369,34 +552,23 @@ def load_model(
|
|||||||
if custom_alignment_heads:
|
if custom_alignment_heads:
|
||||||
alignment_heads = custom_alignment_heads.encode()
|
alignment_heads = custom_alignment_heads.encode()
|
||||||
|
|
||||||
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
|
|
||||||
try:
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
|
|
||||||
if in_memory:
|
|
||||||
checkpoint = load_file(checkpoint_file, device=device)
|
|
||||||
else:
|
|
||||||
checkpoint = load_file(checkpoint_file, device=device)
|
|
||||||
else:
|
|
||||||
with (
|
|
||||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
|
||||||
) as fp:
|
|
||||||
checkpoint = torch.load(fp, map_location=device)
|
|
||||||
del checkpoint_file
|
|
||||||
|
|
||||||
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
||||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||||
state_dict = checkpoint["model_state_dict"]
|
state_dict = checkpoint["model_state_dict"]
|
||||||
else:
|
else:
|
||||||
state_dict = checkpoint
|
state_dict = checkpoint
|
||||||
|
|
||||||
|
if alignment_heads is None and "alignment_heads" in state_dict:
|
||||||
|
alignment_heads = state_dict["alignment_heads"]
|
||||||
|
|
||||||
state_dict = _convert_hf_state_dict(state_dict)
|
state_dict = _convert_hf_state_dict(state_dict)
|
||||||
|
state_dict = _convert_mlx_state_dict(state_dict)
|
||||||
_apply_lora_adapter(state_dict, lora_path)
|
_apply_lora_adapter(state_dict, lora_path)
|
||||||
|
|
||||||
if dims_cfg is not None:
|
if dims_cfg is not None:
|
||||||
dims = ModelDimensions(**dims_cfg)
|
dims = ModelDimensions(**dims_cfg)
|
||||||
else:
|
else:
|
||||||
dims = _infer_dims_from_config(name)
|
dims = _infer_dims_from_config(model_path_for_config)
|
||||||
if dims is None:
|
if dims is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Could not determine model dimensions. "
|
"Could not determine model dimensions. "
|
||||||
@@ -416,8 +588,13 @@ def load_model(
|
|||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
if alignment_heads is not None:
|
if alignment_heads is not None:
|
||||||
|
if isinstance(alignment_heads, bytes):
|
||||||
model.set_alignment_heads(alignment_heads)
|
model.set_alignment_heads(alignment_heads)
|
||||||
|
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
|
||||||
|
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
|
||||||
|
for layer, head in alignment_heads.tolist():
|
||||||
|
mask[layer, head] = True
|
||||||
|
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||||
return model.to(device)
|
return model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -296,10 +296,15 @@ class Tokenizer:
|
|||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
decoded = self.decode_with_timestamps(current_tokens)
|
decoded = self.decode_with_timestamps(current_tokens)
|
||||||
|
|
||||||
if (
|
try:
|
||||||
replacement_char not in decoded
|
replacement_char_index = decoded.index(replacement_char)
|
||||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
replacement_char_index += unicode_offset
|
||||||
== replacement_char
|
except ValueError:
|
||||||
|
replacement_char_index = None
|
||||||
|
|
||||||
|
if replacement_char_index is None or (
|
||||||
|
replacement_char_index < len(decoded_full)
|
||||||
|
and decoded_full[replacement_char_index] == replacement_char
|
||||||
):
|
):
|
||||||
words.append(decoded)
|
words.append(decoded)
|
||||||
word_tokens.append(current_tokens)
|
word_tokens.append(current_tokens)
|
||||||
|
|||||||
200
whisperlivekit/whisper/val.py
Normal file
200
whisperlivekit/whisper/val.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""
|
||||||
|
The most atomic way to train and inference a GPT in pure, dependency-free Python.
|
||||||
|
This file is the complete algorithm.
|
||||||
|
Everything else is just efficiency.
|
||||||
|
|
||||||
|
@karpathy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os # os.path.exists
|
||||||
|
import math # math.log, math.exp
|
||||||
|
import random # random.seed, random.choices, random.gauss, random.shuffle
|
||||||
|
random.seed(42) # Let there be order among chaos
|
||||||
|
|
||||||
|
# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)
|
||||||
|
if not os.path.exists('input.txt'):
|
||||||
|
import urllib.request
|
||||||
|
names_url = 'https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt'
|
||||||
|
urllib.request.urlretrieve(names_url, 'input.txt')
|
||||||
|
docs = [l.strip() for l in open('input.txt').read().strip().split('\n') if l.strip()] # list[str] of documents
|
||||||
|
random.shuffle(docs)
|
||||||
|
print(f"num docs: {len(docs)}")
|
||||||
|
|
||||||
|
# Let there be a Tokenizer to translate strings to discrete symbols and back
|
||||||
|
uchars = sorted(set(''.join(docs))) # unique characters in the dataset become token ids 0..n-1
|
||||||
|
BOS = len(uchars) # token id for the special Beginning of Sequence (BOS) token
|
||||||
|
vocab_size = len(uchars) + 1 # total number of unique tokens, +1 is for BOS
|
||||||
|
print(f"vocab size: {vocab_size}")
|
||||||
|
|
||||||
|
# Let there be Autograd, to recursively apply the chain rule through a computation graph
|
||||||
|
class Value:
|
||||||
|
__slots__ = ('data', 'grad', '_children', '_local_grads') # Python optimization for memory usage
|
||||||
|
|
||||||
|
def __init__(self, data, children=(), local_grads=()):
|
||||||
|
self.data = data # scalar value of this node calculated during forward pass
|
||||||
|
self.grad = 0 # derivative of the loss w.r.t. this node, calculated in backward pass
|
||||||
|
self._children = children # children of this node in the computation graph
|
||||||
|
self._local_grads = local_grads # local derivative of this node w.r.t. its children
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
other = other if isinstance(other, Value) else Value(other)
|
||||||
|
return Value(self.data + other.data, (self, other), (1, 1))
|
||||||
|
|
||||||
|
def __mul__(self, other):
|
||||||
|
other = other if isinstance(other, Value) else Value(other)
|
||||||
|
return Value(self.data * other.data, (self, other), (other.data, self.data))
|
||||||
|
|
||||||
|
def __pow__(self, other): return Value(self.data**other, (self,), (other * self.data**(other-1),))
|
||||||
|
def log(self): return Value(math.log(self.data), (self,), (1/self.data,))
|
||||||
|
def exp(self): return Value(math.exp(self.data), (self,), (math.exp(self.data),))
|
||||||
|
def relu(self): return Value(max(0, self.data), (self,), (float(self.data > 0),))
|
||||||
|
def __neg__(self): return self * -1
|
||||||
|
def __radd__(self, other): return self + other
|
||||||
|
def __sub__(self, other): return self + (-other)
|
||||||
|
def __rsub__(self, other): return other + (-self)
|
||||||
|
def __rmul__(self, other): return self * other
|
||||||
|
def __truediv__(self, other): return self * other**-1
|
||||||
|
def __rtruediv__(self, other): return other * self**-1
|
||||||
|
|
||||||
|
def backward(self):
|
||||||
|
topo = []
|
||||||
|
visited = set()
|
||||||
|
def build_topo(v):
|
||||||
|
if v not in visited:
|
||||||
|
visited.add(v)
|
||||||
|
for child in v._children:
|
||||||
|
build_topo(child)
|
||||||
|
topo.append(v)
|
||||||
|
build_topo(self)
|
||||||
|
self.grad = 1
|
||||||
|
for v in reversed(topo):
|
||||||
|
for child, local_grad in zip(v._children, v._local_grads):
|
||||||
|
child.grad += local_grad * v.grad
|
||||||
|
|
||||||
|
# Initialize the parameters, to store the knowledge of the model.
|
||||||
|
n_embd = 16 # embedding dimension
|
||||||
|
n_head = 4 # number of attention heads
|
||||||
|
n_layer = 1 # number of layers
|
||||||
|
block_size = 16 # maximum sequence length
|
||||||
|
head_dim = n_embd // n_head # dimension of each head
|
||||||
|
matrix = lambda nout, nin, std=0.08: [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]
|
||||||
|
state_dict = {'wte': matrix(vocab_size, n_embd), 'wpe': matrix(block_size, n_embd), 'lm_head': matrix(vocab_size, n_embd)}
|
||||||
|
for i in range(n_layer):
|
||||||
|
state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd)
|
||||||
|
state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd)
|
||||||
|
state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd)
|
||||||
|
state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd)
|
||||||
|
state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd)
|
||||||
|
state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd)
|
||||||
|
params = [p for mat in state_dict.values() for row in mat for p in row] # flatten params into a single list[Value]
|
||||||
|
print(f"num params: {len(params)}")
|
||||||
|
# Define the model architecture: a stateless function mapping token sequence and parameters to logits over what comes next.
|
||||||
|
# Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsnorm, no biases, GeLU -> ReLU
|
||||||
|
|
||||||
|
def linear(x, w):
|
||||||
|
return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w]
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(logits):
|
||||||
|
max_val = max(val.data for val in logits)
|
||||||
|
exps = [(val - max_val).exp() for val in logits]
|
||||||
|
total = sum(exps)
|
||||||
|
return [e / total for e in exps]
|
||||||
|
|
||||||
|
def rmsnorm(x):
|
||||||
|
ms = sum(xi * xi for xi in x) / len(x)
|
||||||
|
scale = (ms + 1e-5) ** -0.5
|
||||||
|
return [xi * scale for xi in x]
|
||||||
|
|
||||||
|
def gpt(token_id, pos_id, keys, values):
|
||||||
|
tok_emb = state_dict['wte'][token_id] # token embedding
|
||||||
|
pos_emb = state_dict['wpe'][pos_id] # position embedding
|
||||||
|
x = [t + p for t, p in zip(tok_emb, pos_emb)] # joint token and position embedding
|
||||||
|
x = rmsnorm(x)
|
||||||
|
|
||||||
|
for li in range(n_layer):
|
||||||
|
# 1) Multi-head attention block
|
||||||
|
x_residual = x
|
||||||
|
x = rmsnorm(x)
|
||||||
|
q = linear(x, state_dict[f'layer{li}.attn_wq'])
|
||||||
|
k = linear(x, state_dict[f'layer{li}.attn_wk'])
|
||||||
|
v = linear(x, state_dict[f'layer{li}.attn_wv'])
|
||||||
|
keys[li].append(k)
|
||||||
|
values[li].append(v)
|
||||||
|
x_attn = []
|
||||||
|
for h in range(n_head):
|
||||||
|
hs = h * head_dim
|
||||||
|
q_h = q[hs:hs+head_dim]
|
||||||
|
k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
|
||||||
|
v_h = [vi[hs:hs+head_dim] for vi in values[li]]
|
||||||
|
attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
|
||||||
|
attn_weights = softmax(attn_logits)
|
||||||
|
head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
|
||||||
|
x_attn.extend(head_out)
|
||||||
|
x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
|
||||||
|
x = [a + b for a, b in zip(x, x_residual)]
|
||||||
|
# 2) MLP block
|
||||||
|
x_residual = x
|
||||||
|
x = rmsnorm(x)
|
||||||
|
x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
|
||||||
|
x = [xi.relu() for xi in x]
|
||||||
|
x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
|
||||||
|
x = [a + b for a, b in zip(x, x_residual)]
|
||||||
|
|
||||||
|
logits = linear(x, state_dict['lm_head'])
|
||||||
|
return logits
|
||||||
|
|
||||||
|
# Let there be Adam, the blessed optimizer and its buffers
|
||||||
|
learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
|
||||||
|
m = [0.0] * len(params) # first moment buffer
|
||||||
|
v = [0.0] * len(params) # second moment buffer
|
||||||
|
# Repeat in sequence
|
||||||
|
num_steps = 1000 # number of training steps
|
||||||
|
for step in range(num_steps):
|
||||||
|
|
||||||
|
# Take single document, tokenize it, surround it with BOS special token on both sides
|
||||||
|
doc = docs[step % len(docs)]
|
||||||
|
tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
|
||||||
|
n = min(block_size, len(tokens) - 1)
|
||||||
|
|
||||||
|
# Forward the token sequence through the model, building up the computation graph all the way to the loss.
|
||||||
|
keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
|
||||||
|
losses = []
|
||||||
|
for pos_id in range(n):
|
||||||
|
token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
|
||||||
|
logits = gpt(token_id, pos_id, keys, values)
|
||||||
|
probs = softmax(logits)
|
||||||
|
loss_t = -probs[target_id].log()
|
||||||
|
losses.append(loss_t)
|
||||||
|
loss = (1 / n) * sum(losses) # final average loss over the document sequence. May yours be low.
|
||||||
|
|
||||||
|
# Backward the loss, calculating the gradients with respect to all model parameters.
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Adam optimizer update: update the model parameters based on the corresponding gradients.
|
||||||
|
lr_t = learning_rate * (1 - step / num_steps) # linear learning rate decay
|
||||||
|
for i, p in enumerate(params):
|
||||||
|
m[i] = beta1 * m[i] + (1 - beta1) * p.grad
|
||||||
|
v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2
|
||||||
|
m_hat = m[i] / (1 - beta1 ** (step + 1))
|
||||||
|
v_hat = v[i] / (1 - beta2 ** (step + 1))
|
||||||
|
p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)
|
||||||
|
p.grad = 0
|
||||||
|
|
||||||
|
print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")
|
||||||
|
|
||||||
|
# Inference: may the model babble back to us
|
||||||
|
temperature = 0.5 # in (0, 1], control the "creativity" of generated text, low to high
|
||||||
|
print("\n--- inference (new, hallucinated names) ---")
|
||||||
|
for sample_idx in range(20):
|
||||||
|
keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
|
||||||
|
token_id = BOS
|
||||||
|
sample = []
|
||||||
|
for pos_id in range(block_size):
|
||||||
|
logits = gpt(token_id, pos_id, keys, values)
|
||||||
|
probs = softmax([l / temperature for l in logits])
|
||||||
|
token_id = random.choices(range(vocab_size), weights=[p.data for p in probs])[0]
|
||||||
|
if token_id == BOS:
|
||||||
|
break
|
||||||
|
sample.append(uchars[token_id])
|
||||||
|
print(f"sample {sample_idx+1:2d}: {''.join(sample)}")
|
||||||
Reference in New Issue
Block a user