118 Commits
0.2.13 ... main

Author SHA1 Message Date
Quentin Fuxa
8bc0937c46 Update README section on powered research 2026-03-06 18:46:07 +01:00
Quentin Fuxa
929cf7a26b add link to AlignAtt interactive playground 2026-03-06 18:43:25 +01:00
Quentin Fuxa
abfaf06203 Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-03-04 18:17:23 +01:00
Quentin Fuxa
d1fe932241 Apply DRY method v0 - to try to catch and resolve infinite loops such as in #338 2026-03-03 22:52:00 +01:00
Quentin Fuxa
c112ceffb6 Merge pull request #342 from mnicnc404/fix/whisper-tokenizer-index-error
fix(whisper/tokenizer): prevent IndexError from crashing multilingual…
2026-03-02 20:36:58 +01:00
Quentin Fuxa
4917406e06 Merge pull request #341 from AymurAI/feat/uv-deps-resolution
deps/docker: align python support, deterministic deps resolution & docker images releases
2026-03-02 20:34:49 +01:00
Chingning Chen
b63f54e838 fix(whisper/tokenizer): prevent IndexError from crashing multilingual streams
This fix addresses a critical bug in the Whisper tokenizer that causes
the transcription server to crash with an `IndexError: string index out
of range` when streaming audio in languages utilizing multi-byte UTF-8
characters (e.g., Cantonese, Japanese, Mandarin).

When a 3-byte character is cut off at the boundary of an audio chunk,
incomplete bytes are decoded into a single Unicode replacement character
(`\ufffd`), artificially shortening the string and breaking the offset
mapping assumed by `split_tokens_on_unicode`.

This ports the upstream fix from SYSTRAN/faster-whisper (PR #111) to add
a strict bounds check before accessing the string index, allowing
incomplete bytes to be safely caught and handled in the next chunk.
2026-03-02 15:31:43 +08:00
jedzill4
c56a53fbf4 deps(mlx-groups): add optional dependencies for Apple Silicon MLX backends 2026-03-01 20:05:52 -03:00
Quentin Fuxa
66e58624b9 disable MLXAlignAtt which fails on special characters 2026-03-01 11:52:00 +01:00
jedzill4
9366e067f9 deps(pyproject): add torch and torchaudio to main dependencies 2026-02-27 19:19:18 -03:00
jedzill4
866c25670c deps(docker): change CUDA base image to runtime version 2026-02-27 19:16:29 -03:00
jedzill4
2553ef283e deps(docker): fix dependency group for cu129 image
- Changed the extras for cu129-diarization-sortformer from gpu-cu129 to cu129.
- This aligns the dependency with the correct naming convention for consistency.
2026-02-25 21:49:08 -03:00
jedzill4
73e7fafc48 feat(tests): python matrix support test
- Introduced a new argument for selecting the diarization backend in the engine creation.
- Enhanced the `create_engine` function to accept and utilize the specified diarization backend.
- Updated the test runner to accommodate the new backend option for improved flexibility.
2026-02-25 21:35:41 -03:00
jedzill4
bbcebcb1fe deps(sortformer): adjust nemo-toolkit version constraints
- Updated the version constraint for `diarization-sortformer` to restrict it to Python 3.10 and below.
2026-02-25 21:33:00 -03:00
jedzill4
4bb58dc7aa deps(diart): improve diart dependency tree. rename gpu-cu129 dependency group to cu129 2026-02-25 20:27:26 -03:00
jedzill4
27ca028479 ci(github): add GitHub Actions workflows for Docker image publishing and support matrix
- Introduced a workflow to publish Docker images on tag push and manual triggers.
- Added a support matrix workflow to test across multiple OS and Python versions.
2026-02-25 14:27:51 -03:00
jedzill4
d24805cc18 🚀 chore (docker): update docker images improving caching and using uv as python package manager 2026-02-25 14:22:43 -03:00
jedzill4
994ce21365 📌 chore(deps): pin dependences to python 3.11 to 3.13 due dependency resolution matrix 2026-02-25 14:21:19 -03:00
jedzill4
132823dc09 deps: improve deps dependency resolution (wip) 2026-02-24 20:15:53 -03:00
jedzill4
d6d8c2635f chore: use uv as python project manager to improve dependency resolution 2026-02-23 22:16:32 -03:00
Quentin Fuxa
8fedeb9fed Merge pull request #340 from QuentinFuxa/voxtral_tests
feat: voxtral-mlx backend, benchmark suite, unit tests, runtime metrics
2026-02-23 10:37:40 +01:00
Quentin Fuxa
b1fc23807a docs: add benchmark collaboration call, voxtral in powered-by section 2026-02-23 10:37:22 +01:00
Quentin Fuxa
10c4e5f730 docs: add speed vs accuracy scatter plot to benchmark and README
WER vs RTF scatter plot showing all backend/policy/model combos
on the 30s English file. Sweet spot zone highlights the best
tradeoffs. Added to both BENCHMARK.md and README.md.
2026-02-23 10:27:53 +01:00
Quentin Fuxa
c76b2ef2c6 docs: rewrite benchmark with base/small comparison, proper French results
- Re-ran all whisper benchmarks with --lan fr for the French file
  (previously ran with --lan en which made the results meaningless)
- Added small model results alongside base for all backends
- Added model size comparison table (base vs small tradeoffs)
- Added benchmark chart (30s English, WER + RTF by backend)
- Added caveats section about dataset size and RTF variance
- Key findings: SimulStreaming saturates at 5.3% WER on base already,
  small model mainly helps LocalAgreement and French timestamps
- mlx-whisper LA base is unstable on French (hallucination loops)
2026-02-23 10:16:34 +01:00
Quentin Fuxa
4b2377c243 fix: correct false auto-detect claim, median bug, RTF inflation
- BENCHMARK.md: whisper also supports --language auto, voxtral is not
  the only one. Fixed mlx-whisper speed comparison (LA is actually
  faster than SS for mlx-whisper, not comparable).
- metrics.py: median calculation was wrong for even-length lists
  (took upper middle instead of averaging the two middle values).
- metrics_collector.py: RTF was inflated because log_summary() used
  wall-clock elapsed time instead of sum of actual ASR call durations.
- README.md: clarified that whisper also supports auto language
  detection, voxtral just does it better.
- Added 2 new median tests (even + odd length).
2026-02-22 23:38:04 +01:00
Quentin Fuxa
a4da246ea5 feat: add voxtral-mlx native backend for Apple Silicon
Pure-MLX implementation of Voxtral Mini 4B Realtime for low-latency
speech transcription on Apple Silicon. Avoids the transformers/torch
overhead and runs at 0.18-0.32x real-time factor.

- voxtral_mlx/model.py: MLX model with spectrogram, encoder, decoder
- voxtral_mlx/loader.py: model loading with 6-bit quantized weights
- voxtral_mlx/spectrogram.py: mel spectrogram computation in MLX
- voxtral_mlx_asr.py: VoxtralASR adapter for the AudioProcessor pipeline
2026-02-22 23:28:10 +01:00
Quentin Fuxa
9b2c3ee844 docs: update README with voxtral backend, benchmarks, testing sections
- Add Voxtral Backend section explaining voxtral-mlx and voxtral (HF).
- Add Testing & Benchmarks section with commands to run tests/benchmarks.
- Update --backend parameter docs to include voxtral-mlx and voxtral.
- Update optional dependencies table with Voxtral entry.
- Link to BENCHMARK.md for detailed performance comparisons.
2026-02-22 23:27:57 +01:00
Quentin Fuxa
83d0fa3fac feat: benchmark suite with WER, timestamp accuracy, cross-backend comparison
- Extend test_backend_offline.py with WER and timestamp accuracy metrics
  computed via whisperlivekit.metrics against ground truth transcripts.
- Add --benchmark flag to auto-detect all installed backends and run
  each (backend, policy) combination in sequence.
- Add --policy flag to override the streaming policy.
- Add detect_available_backends() probing faster-whisper, mlx-whisper,
  voxtral-mlx, voxtral (HF), and openai-whisper.
- Add print_cross_backend_comparison() with per-combo averages.
- Add run_benchmark.py for comprehensive multi-model benchmarking.
- Add BENCHMARK.md with full results on Apple M4: speed, WER,
  timestamp accuracy, VAC impact, and recommendations.
- Add ground truth transcript JSON files for all audio test files.
2026-02-22 23:27:50 +01:00
Quentin Fuxa
5a12c627b4 feat: add 99-test unit test suite with zero model dependencies
Test suite covering:
- metrics.py: WER computation, timestamp accuracy, text normalization
- config.py: defaults, .en model detection, policy aliases, from_namespace
- timed_objects.py: ASRToken, Silence, Transcript, Segment, FrontData
- hypothesis_buffer.py: insert, flush, LCP matching, pop_committed
- silence_handling.py: state machine, double-counting regression test
- audio_processor.py: async pipeline with MockOnlineProcessor

All tests run in ~1.3s without downloading any ASR models.
Add pytest and pytest-asyncio as optional test dependencies.
Update .gitignore to allow tests/ directory.
2026-02-22 23:27:40 +01:00
Quentin Fuxa
f5eee67b11 fix: silence double-counting bug, add metrics module and runtime instrumentation
- Fix _begin_silence pushing same object reference as _end_silence,
  causing the consumer to process two ended events and double the
  silence duration.
- Fix initial silence never cleared when VAC is disabled, causing
  the no-VAC path to enqueue zero audio.
- Add sample-precise silence boundaries (at_sample parameter).
- Add whisperlivekit/metrics.py with WER computation (word-level
  Levenshtein) and timestamp accuracy (greedy alignment). No
  external dependencies.
- Add whisperlivekit/metrics_collector.py with SessionMetrics
  dataclass for per-session runtime observability. Instrumented
  at 6 points in AudioProcessor: init, process_audio,
  transcription_processor, _end_silence, results_formatter, cleanup.
  Emits SESSION_METRICS structured log line on session end.
2026-02-22 23:27:12 +01:00
Quentin Fuxa
4a6868e3e1 correct processor attributes mixtral 2026-02-22 21:13:21 +01:00
Quentin Fuxa
3c15246fc0 mixstral hf v0 2026-02-20 20:49:57 +01:00
Quentin Fuxa
d337248fda feat: add healthcheck to Dockerfiles (#228) 2026-02-20 20:48:28 +01:00
Quentin Fuxa
b8d9d7d289 fix: handle numpy object_ dtype from ctranslate2 encoder (#337) 2026-02-20 20:48:28 +01:00
Quentin Fuxa
4c7706e2cf fix: use vac_chunk_size for audio processing interval when VAC is enabled (#334) 2026-02-20 20:48:06 +01:00
Quentin Fuxa
7f3a3df620 simulstreaming mlx & torch dedup of common base 2025-02-15 23:52:00 +01:00
Quentin Fuxa
e7e82f7c19 bump to 0.2.18 2026-02-11 22:10:00 +01:00
Quentin Fuxa
8c799fa4d1 fix simulstreaming vram leak: cap cross-attn accumulation + token budget
fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during decoding loop,
  using up to ~5GB for repetition loops. now capped to rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
2026-02-11 22:10:00 +01:00
Quentin Fuxa
8923337380 fix --direct-english-translation not setting task=translate for localagreement backends
the flag was only used for tokenizer language selection but never
actually passed to whisper/faster-whisper transcribe calls. also init
OpenaiApiASR.task and read from transcribe_kargs.

fixes #306
2026-02-11 22:10:00 +01:00
Quentin Fuxa
aded1649ae fix model_cache_dir + direct_english_translation task in simulstreaming
pass actual cache dir instead of None, and use proper task string
instead of boolean for AlignAttConfig

fixes #310
2026-02-11 22:10:00 +01:00
Quentin Fuxa
3b535e857a fix NoneType concatenation in add_translation
fixes #296
2026-02-11 22:10:00 +01:00
Quentin Fuxa
d649250b9a fix Segment classmethod call + isinstance type narrowing
fixes #331, fixes #329
2026-02-11 22:10:00 +01:00
Quentin Fuxa
7735478286 add insert_audio_chunk to DiartDiarization
fixes #332
2026-02-11 22:10:00 +01:00
Quentin Fuxa
b9e72d2b9a add probability field to ASRToken
fixes #330, fixes #313
2026-02-11 22:10:00 +01:00
Quentin Fuxa
e5b01033af add json normalizers for english language in build 2026-01-16 10:47:46 +01:00
Quentin Fuxa
6ae545bcb1 bump to 0.2.17.post1 2026-01-16 10:43:52 +01:00
Quentin Fuxa
04980d3f5e Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-01-16 10:38:29 +01:00
Quentin Fuxa
79a705c969 fixes #323 2026-01-16 10:38:07 +01:00
Quentin Fuxa
34e4abd455 Merge pull request #322 from eschmidbauer/fix/thread-safety-issues
Fix kv cache not being properly cleaned between sessions
2026-01-09 19:23:35 +01:00
Emmanuel Schmidbauer
d59ddbaeae Fix critical thread safety issues 2026-01-09 11:23:19 -05:00
Quentin Fuxa
4dd66e7766 Merge pull request #317 from jantonj/fix-bug-diarization-lag
update diarization lag after stream analysed
2025-12-19 17:43:07 +01:00
Anton Jacobson
3db5d81a20 update diarization lag after stream analysed 2025-12-18 14:13:28 +01:00
Quentin Fuxa
b67ddea494 bump to 0.2.17 2025-12-08 23:52:00 +01:00
Quentin Fuxa
3192553e20 fixes #307 2025-12-09 10:27:49 +01:00
Quentin Fuxa
f379a243fe Merge pull request #274 from blakkd/patch-1
minor path change
2025-12-09 10:10:32 +01:00
Quentin Fuxa
ec09898a9f fixes #301 2025-12-06 10:19:50 +01:00
blakkd
befbae56c7 minor path change
prevents

```
FileNotFoundError: [Errno 2] No such file or directory: 'whisperlivekit/web/live_transcription.html'
```
2025-11-16 23:47:58 +01:00
Quentin Fuxa
bbd4fd6cff Merge branch 'improve_EOS_handling' 2025-11-16 22:30:31 +01:00
Quentin Fuxa
28985962a0 Silence handling: finish transcription even if not validated at the BEGINNING of the silence 2025-11-16 22:29:08 +01:00
Quentin Fuxa
a38c103fcd simulstreaming coreml encoder compatibility 2025-11-16 21:24:14 +01:00
Quentin Fuxa
4d2ffb24f8 coreml conversion 2025-11-16 19:11:43 +01:00
Quentin Fuxa
1bbbb7903c lora loader in shared whisper core 2025-11-16 18:44:35 +01:00
Quentin Fuxa
bcffdbc6b3 bump to 0.2.14 2025-11-15 20:19:09 +01:00
Quentin Fuxa
80b77998f9 Refactor backend handling 2025-11-15 19:51:41 +01:00
Quentin Fuxa
d310f7e25f hf compatibility 2025-11-15 18:34:19 +01:00
Quentin Fuxa
8d9be88fe6 translation buffer is now displayed in frontend 2025-11-10 15:22:26 +01:00
Quentin Fuxa
16461052ed task to direct-english-translation 2025-11-10 13:20:26 +01:00
Quentin Fuxa
5491dbd824 last_validated_token handled in state 2025-11-10 13:18:52 +01:00
Quentin Fuxa
13401ffe24 whisper core at root of wlk 2025-11-10 12:17:18 +01:00
Quentin Fuxa
7108d2ddc5 fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/269 2025-11-09 20:08:18 +01:00
Quentin Fuxa
a732e0903e Add a script to detect alignement heads, usefull for distilled whisper 2025-11-09 18:12:09 +01:00
Quentin Fuxa
0491681be4 Distilled model compatibility with HF config.json to ModelDimensions 2025-11-08 20:20:05 +01:00
Quentin Fuxa
ffe5284764 _processing_tasks_done checks task completion 2025-11-05 23:34:00 +01:00
Quentin Fuxa
719e8b1a20 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
f1b47178d8 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
59db08e961 loader for full mlx 2024-11-25 23:52:00 +01:00
Quentin Fuxa
6fc20b9562 new dec class 2024-11-21 23:52:00 +01:00
Quentin Fuxa
fac8659161 uses native mlx function for attention 2024-11-21 23:52:00 +01:00
Quentin Fuxa
4d9332ce7d fixes #299 2025-12-05 17:54:14 +01:00
Quentin Fuxa
62444ce746 session parameter required in OnnxWrapper 2025-12-05 15:37:18 +01:00
Quentin Fuxa
2431a6bf91 isolated VAD states per user: .onnx: share a stateless model. .jit: require duplicating the model.
Co-authored-by: eschmidbauer <eschmidbauer@gmail.com>
2025-12-05 15:27:14 +01:00
Quentin Fuxa
d1263e7228 Merge pull request #308 from gzz2000/main
Fix local agreement backend, removing excess parameter, #295
2025-12-05 11:34:05 +01:00
Zizheng Guo
30ddd522a4 Fix local agreement backend, removing excess parameter, fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/295 2025-12-04 16:45:23 +08:00
Quentin Fuxa
635bace09e update archi 2025-11-30 18:39:10 +01:00
Quentin Fuxa
f1113e3eb0 update with LoRA 2025-11-29 18:33:30 +01:00
Quentin Fuxa
cc5f819ce7 hf weights 2025-11-29 17:50:46 +01:00
Quentin Fuxa
82cd24bb75 LoRa path v0 - functional 2025-11-29 17:21:10 +01:00
Quentin Fuxa
d45c397c6a simulstreaming: limit n tokens to prevent hallucinations 2025-11-28 21:41:19 +01:00
Quentin Fuxa
45bf3f57d7 troubleshooting doc for aarch64 systems 2025-11-28 21:40:43 +01:00
Quentin Fuxa
1d88ba9d69 Fixes #294. improve model path backend detection and file extraction 2025-11-27 23:14:00 +01:00
Quentin Fuxa
c0965c6c31 Lines to Segments. Merging dataclasses 2025-11-27 21:54:58 +01:00
Quentin Fuxa
34ddd2ac02 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
345d781e97 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
28cf831701 indicate for context token limits for --max-context-tokens. bump to 0.2.16.dev0 2025-11-25 23:45:15 +01:00
Quentin Fuxa
60c62f8f84 troubleshooting #271 #276 #284 #286 2025-11-25 23:31:46 +01:00
Quentin Fuxa
7faa21f95f alignatt: enable model sharing by removing hooks and centralizing session state. Solves #282
Co-authored-by: Emmanuel Schmidbauer <eschmidbauer@gmail.com>
2025-11-25 23:07:42 +01:00
Quentin Fuxa
4e9f951551 correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
870141298c isort 2025-11-23 11:20:00 +01:00
Quentin Fuxa
872faa422a correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
fc9cb66813 disabling vac is not advised 2025-11-23 11:20:00 +01:00
Quentin Fuxa
a175d1a327 fixes silence detected but never reported by silero 2025-11-23 11:20:00 +01:00
Quentin Fuxa
6206fff118 0.2.15 2025-11-21 23:52:00 +01:00
Quentin Fuxa
b5067249c0 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
f4f9831d39 stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
254faaf64c stt/diar/nllw alignment: internal rework 5 2025-11-20 23:52:00 +01:00
Quentin Fuxa
8e7aea4fcf internal rework 4 2025-11-20 23:45:20 +01:00
Quentin Fuxa
270faf2069 internal rework 3 2025-11-20 22:28:30 +01:00
Quentin Fuxa
b7c1cc77cc internal rework 2 2025-11-20 22:06:38 +01:00
Quentin Fuxa
9a45ec221c internal rework 1 2025-11-20 12:58:38 +01:00
Quentin Fuxa
3e13ee6fc3 bump to post4 2025-11-19 21:23:43 +01:00
Quentin Fuxa
b7d20a0ff0 segment attribution in result formatter 2025-11-19 21:10:28 +01:00
Quentin Fuxa
c1bb9c2bde reduce flickering remaining_time_transcription 2025-11-19 19:09:37 +01:00
Quentin Fuxa
11e9def0b2 diarization corrections 2025-11-19 19:06:03 +01:00
Quentin Fuxa
3104f40f6e fixes #279 #278 2025-11-19 18:17:50 +01:00
Quentin Fuxa
e9b4ceeee5 Add audio partial silence in chunks handling. bump to 0.2.14.post3 2025-11-17 22:52:00 +01:00
Quentin Fuxa
437641fb43 reduce min-chunk-size to 0.1, set default model to base 2027-04-25 23:52:00 +02:00
Quentin Fuxa
bfd60b3921 Add audio partial silence in chunks handling. bump to 0.2.14.post2 2025-11-17 22:52:00 +01:00
Quentin Fuxa
1e67bf97f0 improve buffering when use of heavy models 2027-04-25 23:52:00 +02:00
118 changed files with 18678 additions and 3221 deletions

13
.dockerignore Normal file
View File

@@ -0,0 +1,13 @@
.git
.github
.venv
__pycache__
*.pyc
.pytest_cache
.mypy_cache
.ruff_cache
.cache
.tmp
.secrets
dist
build

61
.github/workflows/publish-docker.yml vendored Normal file
View File

@@ -0,0 +1,61 @@
name: Publish Docker Images
on:
push:
tags:
- "v*"
workflow_dispatch:
inputs:
tag:
description: "Image tag to publish (without image suffix)"
required: true
type: string
permissions:
contents: read
packages: write
jobs:
docker:
runs-on: ubuntu-latest
env:
IMAGE_TAG: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.tag || github.ref_name }}
strategy:
fail-fast: false
matrix:
include:
- image_suffix: cpu-diarization-sortformer
dockerfile: Dockerfile.cpu
extras: cpu,diarization-sortformer
- image_suffix: cu129-diarization-sortformer
dockerfile: Dockerfile
extras: cu129,diarization-sortformer
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set lowercase owner
id: owner
run: echo "value=${GITHUB_REPOSITORY_OWNER,,}" >> "${GITHUB_OUTPUT}"
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Setup Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push image
uses: docker/build-push-action@v6
with:
context: .
file: ./${{ matrix.dockerfile }}
push: true
build-args: |
EXTRAS=${{ matrix.extras }}
tags: |
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:${{ env.IMAGE_TAG }}-${{ matrix.image_suffix }}
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:latest-${{ matrix.image_suffix }}

6
.gitignore vendored
View File

@@ -119,9 +119,11 @@ run_*.sh
*.pt *.pt
# Debug & testing # Debug & testing
test_*.py /test_*.py
!test_backend_offline.py
launch.json launch.json
.DS_Store .DS_Store
test/* /test/
!tests/
nllb-200-distilled-600M-ctranslate2/* nllb-200-distilled-600M-ctranslate2/*
*.mp3 *.mp3

205
BENCHMARK.md Normal file
View File

@@ -0,0 +1,205 @@
# WhisperLiveKit Benchmark Report
Benchmark comparing all supported ASR backends, streaming policies, and model sizes on Apple Silicon.
All tests run through the full AudioProcessor pipeline (same code path as production WebSocket).
## Test Environment
| Property | Value |
|----------|-------|
| Hardware | Apple M4, 32 GB RAM |
| OS | macOS 25.3.0 (arm64) |
| Python | 3.13 |
| faster-whisper | 1.2.1 |
| mlx-whisper | installed (via mlx) |
| Voxtral MLX | native MLX backend |
| Voxtral (HF) | transformers-based |
| VAC (Silero VAD) | enabled unless noted |
| Chunk size | 100 ms |
| Pacing | no-realtime (as fast as possible) |
## Audio Test Files
| File | Duration | Language | Speakers | Description |
|------|----------|----------|----------|-------------|
| `00_00_07_english_1_speaker.wav` | 7.2 s | English | 1 | Short dictation with pauses |
| `00_00_16_french_1_speaker.wav` | 16.3 s | French | 1 | French speech with intentional silence gaps |
| `00_00_30_english_3_speakers.wav` | 30.0 s | English | 3 | Multi-speaker conversation |
Ground truth transcripts (`.transcript.json`) with per-word timestamps are hand-verified.
---
## Results
### English -- Short (7.2 s, 1 speaker)
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|---------|--------|-------|-----|-----|---------------|
| faster-whisper | LocalAgreement | base | 0.20x | 21.1% | 0.080 s |
| faster-whisper | SimulStreaming | base | 0.14x | 0.0% | 0.239 s |
| faster-whisper | LocalAgreement | small | 0.59x | 21.1% | 0.089 s |
| faster-whisper | SimulStreaming | small | 0.39x | 0.0% | 0.221 s |
| mlx-whisper | LocalAgreement | base | 0.05x | 21.1% | 0.080 s |
| mlx-whisper | SimulStreaming | base | 0.14x | 10.5% | 0.245 s |
| mlx-whisper | LocalAgreement | small | 0.16x | 21.1% | 0.089 s |
| mlx-whisper | SimulStreaming | small | 0.20x | 10.5% | 0.226 s |
| voxtral-mlx | voxtral | 4B | 0.32x | 0.0% | 0.254 s |
| voxtral (HF) | voxtral | 4B | 1.29x | 0.0% | 1.876 s |
### English -- Multi-speaker (30.0 s, 3 speakers)
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|---------|--------|-------|-----|-----|---------------|
| faster-whisper | LocalAgreement | base | 0.24x | 44.7% | 0.235 s |
| faster-whisper | SimulStreaming | base | 0.10x | 5.3% | 0.398 s |
| faster-whisper | LocalAgreement | small | 0.59x | 25.0% | 0.226 s |
| faster-whisper | SimulStreaming | small | 0.26x | 5.3% | 0.387 s |
| mlx-whisper | LocalAgreement | base | 0.06x | 23.7% | 0.237 s |
| mlx-whisper | SimulStreaming | base | 0.11x | 5.3% | 0.395 s |
| mlx-whisper | LocalAgreement | small | 0.13x | 25.0% | 0.226 s |
| mlx-whisper | SimulStreaming | small | 0.20x | 5.3% | 0.394 s |
| voxtral-mlx | voxtral | 4B | 0.31x | 9.2% | 0.176 s |
| voxtral (HF) | voxtral | 4B | 1.00x | 32.9% | 1.034 s |
<p align="center">
<img src="benchmark_chart.png" alt="Benchmark comparison on 30s English" width="800">
</p>
<p align="center">
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
</p>
### French (16.3 s, 1 speaker, `--language fr`)
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|---------|--------|-------|-----|-----|---------------|
| faster-whisper | LocalAgreement | base | 0.22x | 25.7% | 3.460 s |
| faster-whisper | SimulStreaming | base | 0.10x | 31.4% | 3.660 s |
| faster-whisper | LocalAgreement | small | 0.76x | 42.9% | 0.051 s |
| faster-whisper | SimulStreaming | small | 0.29x | 25.7% | 0.219 s |
| mlx-whisper | LocalAgreement | base | 0.09x | ~45%* | ~5.0 s* |
| mlx-whisper | SimulStreaming | base | 0.09x | 40.0% | 3.540 s |
| mlx-whisper | LocalAgreement | small | 0.14x | 25.7% | 0.083 s |
| mlx-whisper | SimulStreaming | small | 0.17x | 31.4% | 0.203 s |
| voxtral-mlx | voxtral | 4B | 0.18x | 37.1% | 3.422 s |
| voxtral (HF) | voxtral | 4B | 0.63x | 28.6% | 4.040 s |
\* mlx-whisper + LocalAgreement + base is unstable on this French file (WER fluctuates 34-1037% across runs due to hallucination loops). The `small` model does not have this problem.
**Timestamp note:** The base model produces very high timestamp MAE (3.4-3.7s) on this French file because it misaligns words around the silence gaps. The small model handles this much better (0.05-0.22s MAE). Voxtral also drifts on the silence gaps.
---
## Model Size Comparison (base vs small)
| | base | small | Observation |
|--|------|-------|-------------|
| **RTF** | 0.05-0.24x | 0.13-0.76x | small is 2-3x slower |
| **English WER (SS)** | 0-5.3% | 0-5.3% | No improvement: SimulStreaming already saturates on base |
| **English WER (LA)** | 21-44.7% | 21-25% | small reduces LA errors on longer audio |
| **French WER** | 25-40% | 25-43% | Mixed: depends on backend/policy combo |
| **French timestamps** | 3.4-5.0s MAE | 0.05-0.22s MAE | small is dramatically better for French timestamps |
In short: **base + SimulStreaming** gives the best speed/accuracy tradeoff for English. The small model only helps if you need LocalAgreement (for subtitle-grade timestamps) or non-English languages.
---
## Key Findings
### Speed (RTF = processing time / audio duration, lower is better)
1. **mlx-whisper + LocalAgreement + base** is the fastest combo on Apple Silicon: 0.05-0.06x RTF on English. 30 seconds of audio in under 2 seconds.
2. For **faster-whisper**, SimulStreaming is faster than LocalAgreement. For **mlx-whisper**, it is the opposite: LocalAgreement (0.05-0.06x) outperforms SimulStreaming (0.11-0.14x) on speed.
3. **voxtral-mlx** runs at 0.18-0.32x RTF -- 3-5x slower than mlx-whisper base, but well within real-time.
4. **voxtral (HF transformers)** hits 1.0-1.3x RTF. At the real-time boundary on Apple Silicon. Use the MLX variant instead.
5. The **small** model is 2-3x slower than base across all backends.
### Accuracy (WER = Word Error Rate, lower is better)
1. **SimulStreaming** gives dramatically lower WER than LocalAgreement on the whisper backends. On the 30s English file: 5.3% vs 23-44%.
2. **voxtral-mlx** hits 0% on short English and 9.2% on multi-speaker. It auto-detects language natively. Whisper also supports `--language auto`, but tends to bias towards English on short segments.
3. **LocalAgreement** tends to repeat the last sentence at end-of-stream (a known LCP artifact), inflating WER. This is visible in the 21% WER on the 7s file -- the same 4 extra words appear in every LA run.
4. On **French** with the correct `--language fr`, whisper base achieves 25-40% WER -- comparable to Voxtral's 28-37%. The small model does not consistently improve French WER.
### Timestamps (MAE = Mean Absolute Error on word start times)
1. **LocalAgreement** gives the best timestamps on English (0.08-0.09s MAE).
2. **SimulStreaming** is less precise (0.22-0.40s MAE) but good enough for most applications.
3. On French with silence gaps, **base model timestamps are unreliable** (3.4-5s MAE). The **small model fixes this** (0.05-0.22s MAE). This is the strongest argument for using `small` over `base`.
4. **voxtral-mlx** has good timestamps on English (0.18-0.25s MAE) but drifts on audio with long silence gaps (3.4s MAE on the French file).
### VAC (Voice Activity Classification) Impact
| Backend | Policy | VAC | 7s English WER | 30s English WER |
|---------|--------|-----|----------------|-----------------|
| faster-whisper | LocalAgreement | on | 21.1% | 44.7% |
| faster-whisper | LocalAgreement | off | 100.0% | 100.0% |
| voxtral-mlx | voxtral | on | 0.0% | 9.2% |
| voxtral-mlx | voxtral | off | 0.0% | 9.2% |
- **Whisper backends need VAC** to work in streaming mode. Without it the buffer logic breaks down and you get empty or garbage output.
- **Voxtral is unaffected by VAC** since it handles its own internal chunking. Identical results with or without. VAC still saves compute on silent segments.
---
## Recommendations
| Use Case | Backend | Policy | Model | Notes |
|----------|---------|--------|-------|-------|
| Fastest English (Apple Silicon) | mlx-whisper | SimulStreaming | base | 0.11x RTF, 5.3% WER |
| Fastest English (Linux/GPU) | faster-whisper | SimulStreaming | base | 0.10x RTF, 5.3% WER |
| Best accuracy, English | faster-whisper | SimulStreaming | small | 0.26x RTF, 5.3% WER, still fast |
| Multilingual / auto-detect | voxtral-mlx | voxtral | 4B | 100+ languages, 0.18-0.32x RTF |
| Best timestamps | any | LocalAgreement | small | 0.05-0.09s MAE, good for subtitles |
| Low memory / embedded | mlx-whisper | SimulStreaming | base | Smallest footprint, fastest response |
---
## Caveats
- **3 test files, ~53 seconds total.** Results give relative rankings between backends but should not be taken as definitive WER numbers. Run on your own data for production decisions.
- **RTF varies between runs** (up to +/-30%) depending on thermal state, background processes, and model caching. The numbers above are single sequential runs on a warm machine.
- **Only base and small tested.** Medium and large-v3 would likely improve WER at the cost of higher RTF. We did not test them here because they are slow on Apple Silicon without GPU.
---
## Reproducing These Benchmarks
```bash
# Install test dependencies
pip install -e ".[test]"
# Single backend test
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --model base --no-realtime
# With a specific language
python test_backend_offline.py --backend mlx-whisper --policy simulstreaming --model small --lan fr --no-realtime
# Multi-backend auto-detect benchmark
python test_backend_offline.py --benchmark --no-realtime
# Export to JSON
python test_backend_offline.py --benchmark --no-realtime --json results.json
# Test with your own audio
python test_backend_offline.py --backend voxtral-mlx --audio your_file.wav --no-realtime
```
The benchmark harness computes WER and timestamp accuracy automatically when ground truth
`.transcript.json` files exist alongside the audio files. See `audio_tests/` for the format.
---
## Help Us Benchmark on More Hardware
These results are from a single Apple M4 machine. We'd love to see numbers from other setups: Linux with CUDA GPUs, older Macs, different CPU architectures, cloud instances, etc.
If you run the benchmark on your hardware, please open an issue or PR with your results and we will add them here. The more data points we have, the better the recommendations get.
What we are especially interested in:
- **NVIDIA GPUs** (RTX 3090, 4090, A100, T4, etc.) with faster-whisper
- **Older Apple Silicon** (M1, M2, M3) with mlx-whisper and voxtral-mlx
- **Medium and large-v3 models** (we only tested base and small so far)
- **Longer audio files** or domain-specific audio (medical, legal, call center)
- **Other languages** beyond English and French

View File

@@ -1,83 +1,75 @@
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
# --- MARK: Builder Stage
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 AS builder-gpu
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
WORKDIR /app WORKDIR /app
ARG EXTRAS RUN apt-get update && \
ARG HF_PRECACHE_DIR apt-get install -y --no-install-recommends \
ARG HF_TKN_FILE build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install UV and set up the environment
COPY --from=uvbin /uv /uvx /bin/
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
ENV UV_PYTHON_PREFERENCE=only-managed
ENV UV_PYTHON_INSTALL_DIR=/python
RUN uv python install 3.12
# Install dependencies first to leverage caching
ARG EXTRAS=cu129
COPY pyproject.toml uv.lock /app/
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
# Copy the source code and install the package only
COPY whisperlivekit /app/whisperlivekit
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-editable --no-cache "$@"
# --- MARK: Runtime Stage
FROM nvidia/cuda:12.9.1-cudnn-runtime-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
python3 \ ffmpeg &&\
python3-pip \ rm -rf /var/lib/apt/lists/*
python3-venv \
ffmpeg \
git \
build-essential \
python3-dev \
ca-certificates && \
rm -rf /var/lib/apt/lists/*
RUN python3 -m venv /opt/venv # Copy UV binaries
ENV PATH="/opt/venv/bin:$PATH" COPY --from=uvbin /uv /uvx /bin/
# timeout/retries for large torch wheels # Copy the Python version
RUN pip3 install --upgrade pip setuptools wheel && \ COPY --from=builder-gpu --chown=python:python /python /python
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
COPY . . # Copy the virtual environment with all dependencies installed
COPY --from=builder-gpu /app/.venv /app/.venv
# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# In-container caching for Hugging Face models by:
# A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is
# only for convenience at de/test stage.
# For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]
# or
# B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg.
# WARNING: This will copy ALL files in the pre-cache location.
# Conditionally copy a cache directory if provided
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
EXPOSE 8000 EXPOSE 8000
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_PYTHON_DOWNLOADS=0
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"] ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
CMD ["--model", "medium"] CMD ["--model", "medium"]

View File

@@ -1,61 +1,76 @@
FROM python:3.13-slim FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
# --- MARK: Builder Stage
FROM debian:bookworm-slim AS builder-cpu
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
WORKDIR /app WORKDIR /app
ARG EXTRAS RUN apt-get update && \
ARG HF_PRECACHE_DIR apt-get install -y --no-install-recommends \
ARG HF_TKN_FILE build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install UV and set up the environment
COPY --from=uvbin /uv /uvx /bin/
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
ENV UV_PYTHON_PREFERENCE=only-managed
ENV UV_PYTHON_INSTALL_DIR=/python
RUN uv python install 3.12
# Install dependencies first to leverage caching
ARG EXTRAS=cpu
COPY pyproject.toml uv.lock /app/
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
# Copy the source code and install the package only
COPY whisperlivekit /app/whisperlivekit
RUN set -eux; \
set --; \
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
set -- "$@" --extra "$extra"; \
done; \
uv sync --frozen --no-editable --no-cache "$@"
# --- MARK: Runtime Stage
FROM debian:bookworm-slim
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
ffmpeg \ ffmpeg &&\
git \ rm -rf /var/lib/apt/lists/*
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch # Copy UV binaries
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu COPY --from=uvbin /uv /uvx /bin/
COPY . . # Copy the Python version
COPY --from=builder-cpu --chown=python:python /python /python
# Install WhisperLiveKit directly, allowing for optional dependencies # Copy the virtual environment with all dependencies installed
RUN if [ -n "$EXTRAS" ]; then \ COPY --from=builder-cpu /app/.venv /app/.venv
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]
# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000 EXPOSE 8000
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_PYTHON_DOWNLOADS=0
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"] ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args - you might want to use a smaller model for CPU # Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"] CMD ["--model", "tiny"]

173
README.md
View File

@@ -1,28 +1,32 @@
<h1 align="center">WhisperLiveKit</h1> <h1 align="center">WLK</h1>
<p align="center"><b>WhisperLiveKit: Ultra-low-latency, self-hosted speech-to-text with speaker identification</b></p>
<p align="center"> <p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730"> <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p> </p>
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>
<p align="center"> <p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a> <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://huggingface.co/qfuxa/whisper-base-french-lora">
<img alt="Hugging Face Weights" src="https://img.shields.io/badge/🤗-Hugging%20Face%20Weights-yellow" />
</a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a> <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
</p> </p>
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ### Powered by Leading Research:
#### Powered by Leading Research: **See the interactive playground in [this repo](https://github.com/QuentinFuxa/streamlit-d3-network) to explore how AlignAtt works**
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408).
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages. - [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf) - [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization - [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization - [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) (2025) - 4B-parameter multilingual speech model by Mistral AI
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection - [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
@@ -45,15 +49,18 @@ pip install whisperlivekit
#### Quick Start #### Quick Start
1. **Start the transcription server:** 1. **Start the transcription server:**
```bash ```bash
whisperlivekit-server --model base --language en wlk --model base --language en
``` ```
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time! 2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages. > - See [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - Check the [troubleshooting guide](docs/troubleshooting.md) for step-by-step fixes collected from recent GPU setup/env issues.
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options. > - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
#### Use it to capture audio from web pages. #### Use it to capture audio from web pages.
Go to `chrome-extension` for instructions. Go to `chrome-extension` for instructions.
@@ -66,41 +73,87 @@ Go to `chrome-extension` for instructions.
#### Optional Dependencies #### Optional Dependencies
| Optional | `pip install` | | Feature | `uv sync` | `pip install -e` |
|-----------|-------------| |-----------|-------------|-------------|
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` | | **Apple Silicon MLX Whisper backend** | `uv sync --extra mlx-whisper` | `pip install -e ".[mlx-whisper]"` |
| **Apple Silicon optimizations** | `mlx-whisper` | | **Voxtral (MLX backend, Apple Silicon)** | `uv sync --extra voxtral-mlx` | `pip install -e ".[voxtral-mlx]"` |
| **Translation** | `nllw` | | **CPU PyTorch stack** | `uv sync --extra cpu` | `pip install -e ".[cpu]"` |
| *[Not recommanded]* Speaker diarization with Diart | `diart` | | **CUDA 12.9 PyTorch stack** | `uv sync --extra cu129` | `pip install -e ".[cu129]"` |
| *[Not recommanded]* Original Whisper backend | `whisper` | | **Translation** | `uv sync --extra translation` | `pip install -e ".[translation]"` |
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` | | **Sentence tokenizer** | `uv sync --extra sentence_tokenizer` | `pip install -e ".[sentence_tokenizer]"` |
| OpenAI API backend | `openai` | | **Voxtral (HF backend)** | `uv sync --extra voxtral-hf` | `pip install -e ".[voxtral-hf]"` |
| **Speaker diarization (Sortformer / NeMo)** | `uv sync --extra diarization-sortformer` | `pip install -e ".[diarization-sortformer]"` |
| *[Not recommended]* Speaker diarization with Diart | `uv sync --extra diarization-diart` | `pip install -e ".[diarization-diart]"` |
See **Parameters & Configuration** below on how to use them. Supported GPU profiles:
```bash
# Profile A: Sortformer diarization
uv sync --extra cu129 --extra diarization-sortformer
# Profile B: Voxtral HF + translation
uv sync --extra cu129 --extra voxtral-hf --extra translation
```
`voxtral-hf` and `diarization-sortformer` are intentionally incompatible extras and must be installed in separate environments.
See **Parameters & Configuration** below on how to use them.
<p align="center">
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
</p>
See **[BENCHMARK.md](BENCHMARK.md)** for the full benchmark with tables, model size comparison, and more.
We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR!
### Voxtral Backend
WhisperLiveKit supports [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602),
a 4B-parameter speech model from Mistral AI that natively handles 100+ languages with automatic
language detection. Whisper also supports auto-detection (`--language auto`), but Voxtral's per-chunk
detection is more reliable and does not bias towards English.
```bash
# Apple Silicon (native MLX, recommended)
pip install -e ".[voxtral-mlx]"
wlk --backend voxtral-mlx
# Linux/GPU (HuggingFace transformers)
pip install transformers torch
wlk --backend voxtral
```
Voxtral uses its own streaming policy and does not use LocalAgreement or SimulStreaming.
See [BENCHMARK.md](BENCHMARK.md) for performance numbers.
### Usage Examples ### Usage Examples
**Command-line Interface**: Start the transcription server with various options: **Command-line Interface**: Start the transcription server with various options:
```bash ```bash
# Large model and translate from french to danish # Large model and translate from french to danish
whisperlivekit-server --model large-v3 --language fr --target-language da wlk --model large-v3 --language fr --target-language da
# Diarization and server listening on */80 # Diarization and server listening on */80
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
# Voxtral multilingual (auto-detects language)
wlk --backend voxtral-mlx
``` ```
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes. **Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
```python ```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio from whisperlivekit import AudioProcessor, TranscriptionEngine, parse_args
transcription_engine = None transcription_engine = None
@@ -139,15 +192,15 @@ async def websocket_endpoint(websocket: WebSocket):
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` | | `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
| `--model-path` | .pt file/directory containing whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` | | `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` | | `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translate to using NLLB. Ex: `fr`. [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` | | `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` |
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` | | `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--no-vac` | Disable Voice Activity Controller | `False` | | `--backend` | ASR backend selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. Options: `mlx-whisper`, `faster-whisper`, `whisper`, `openai-api` (LocalAgreement only), `voxtral-mlx` (Apple Silicon), `voxtral` (HuggingFace) | `auto` |
| `--no-vad` | Disable Voice Activity Detection | `False` | | `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` | | `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` | | `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` | | `--port` | Server port | `8000` |
@@ -155,6 +208,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` | | `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` | | `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). Only works with native Whisper backend (`--backend whisper`) | `None` |
| Translation options | Description | Default | | Translation options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
@@ -164,14 +218,15 @@ async def websocket_endpoint(websocket: WebSocket):
| Diarization options | Description | Default | | Diarization options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` | | `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` | | `--disable-punctuation-split` | [NOT FUNCTIONAL IN 0.2.15 / 0.2.16] Disable punctuation based splits. See #214 | `False` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` | | `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` | | `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
| SimulStreaming backend options | Description | Default | | SimulStreaming backend options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` | | `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` | | `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
| `None` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` | | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` | | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` | | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -181,8 +236,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--never-fire` | Never truncate incomplete words | `False` | | `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` | | `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` | | `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` | | `--max-context-tokens` | Maximum context tokens | Depends on model used, but usually 448. |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
@@ -241,7 +295,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk
**CPU only:** **CPU only:**
```bash ```bash
docker build -f Dockerfile.cpu -t wlk . docker build -f Dockerfile.cpu -t wlk --build-arg EXTRAS="cpu" .
docker run -p 8000:8000 --name wlk wlk docker run -p 8000:8000 --name wlk wlk
``` ```
@@ -253,6 +307,18 @@ docker run -p 8000:8000 --name wlk wlk
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
``` ```
**Compose (recommended for cache + token wiring):**
```bash
# GPU Sortformer profile
docker compose up --build wlk-gpu-sortformer
# GPU Voxtral profile
docker compose up --build wlk-gpu-voxtral
# CPU service
docker compose up --build wlk-cpu
```
### Memory Requirements ### Memory Requirements
- **Large models**: Ensure your Docker runtime has sufficient memory allocated - **Large models**: Ensure your Docker runtime has sufficient memory allocated
@@ -260,9 +326,34 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
#### Customization #### Customization
- `--build-arg` Options: - `--build-arg` Options:
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options! - `EXTRAS="cu129,diarization-sortformer"` - GPU Sortformer profile extras.
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start - `EXTRAS="cu129,voxtral-hf,translation"` - GPU Voxtral profile extras.
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models - `EXTRAS="cpu,diarization-diart,translation"` - CPU profile extras.
- Hugging Face cache + token are configured in `compose.yml` using a named volume and `HF_TKN_FILE` (default: `./token`).
## 🔮 Use Cases ## Testing & Benchmarks
WhisperLiveKit includes a unit test suite and an offline benchmark harness.
```bash
# Install test dependencies
pip install -e ".[test]"
# Run unit tests (no model download required)
pytest tests/ -v
# Benchmark a single backend
python test_backend_offline.py --backend faster-whisper --no-realtime
# Benchmark all installed backends
python test_backend_offline.py --benchmark --no-realtime
# Export benchmark results as JSON
python test_backend_offline.py --benchmark --no-realtime --json results.json
```
See [BENCHMARK.md](BENCHMARK.md) for a full comparison of backends, policies, WER, speed, and
timestamp accuracy on Apple Silicon.
## Use Cases
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service... Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...

View File

@@ -1,258 +0,0 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
#### 主要な研究による技術:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
### アーキテクチャ
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
### インストールとクイックスタート
```bash
pip install whisperlivekit
```
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
>
> | OS | インストール方法 |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
#### クイックスタート
1. **文字起こしサーバーを起動します:**
```bash
whisperlivekit-server --model base --language en
```
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
#### オプションの依存関係
| オプション | `pip install` |
|-----------|-------------|
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Diartによる話者ダイアライゼーション | `diart` |
| オリジナルのWhisperバックエンド | `whisper` |
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
| Apple Silicon最適化バックエンド | `mlx-whisper` |
| OpenAI APIバックエンド | `openai` |
それらの使用方法については、以下の**パラメータと設定**を参照してください。
### 使用例
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
```bash
# デフォルト(small)より良いモデルを使用
whisperlivekit-server --model large-v3
# ダイアライゼーションと言語を指定した高度な設定
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket.accept()
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
```
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
## パラメータと設定
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
- `--backend` `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
- `--warmup-file`、もしあれば
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
- `--diarization`、使用したい場合。
残りは推奨しません。しかし、以下があなたのオプションです。
| パラメータ | 説明 | デフォルト |
|-----------|-------------|---------|
| `--model` | Whisperモデルのサイズ。 | `small` |
| `--language` | ソース言語コードまたは`auto` | `auto` |
| `--task` | `transcribe`または`translate` | `transcribe` |
| `--backend` | 処理バックエンド | `simulstreaming` |
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
| `--no-vad` | 音声区間検出を無効化 | `False` |
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
| `--host` | サーバーホストアドレス | `localhost` |
| `--port` | サーバーポート | `8000` |
| `--ssl-certfile` | SSL証明書ファイルへのパスHTTPSサポート用 | `None` |
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパスHTTPSサポート用 | `None` |
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment` | `segment` |
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAttフレームしきい値低いほど速く、高いほど正確 | `25` |
| `--beams` | ビームサーチのビーム数1 = 貪欲デコーディング) | `1` |
| `--decoder` | デコーダタイプを強制(`beam`または`greedy` | `auto` |
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
| `--init-prompt` | モデルの初期プロンプト | `None` |
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
| ダイアライゼーションオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--diarization` | 話者識別を有効化 | `False` |
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
>4. HuggingFaceでログイン: `huggingface-cli login`
### 🚀 デプロイガイド
WhisperLiveKitを本番環境にデプロイするには
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
3. **Nginx設定** (本番環境で推奨):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
## 🐋 Docker
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
### 前提条件
- Dockerがシステムにインストールされていること
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
### クイックスタート
**GPUアクセラレーション付き (推奨):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
**CPUのみ:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### 高度な使用法
**カスタム設定:**
```bash
# カスタムモデルと言語の例
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### メモリ要件
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
#### カスタマイズ
- `--build-arg` オプション:
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
## 🔮 ユースケース
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...

Binary file not shown.

Before

Width:  |  Height:  |  Size: 406 KiB

After

Width:  |  Height:  |  Size: 422 KiB

View File

@@ -0,0 +1,97 @@
[
{
"word": "This",
"start": 0.0,
"end": 0.24
},
{
"word": "is",
"start": 0.24,
"end": 0.56
},
{
"word": "a",
"start": 0.56,
"end": 0.76
},
{
"word": "transcription",
"start": 0.76,
"end": 1.32
},
{
"word": "test.",
"start": 1.32,
"end": 2.0
},
{
"word": "We",
"start": 2.4,
"end": 2.5
},
{
"word": "want",
"start": 2.5,
"end": 2.66
},
{
"word": "to",
"start": 2.66,
"end": 2.84
},
{
"word": "see",
"start": 2.84,
"end": 3.1
},
{
"word": "if",
"start": 3.1,
"end": 3.34
},
{
"word": "we",
"start": 3.34,
"end": 3.5
},
{
"word": "can",
"start": 3.5,
"end": 3.68
},
{
"word": "use",
"start": 3.68,
"end": 4.04
},
{
"word": "smaller",
"start": 4.04,
"end": 4.76
},
{
"word": "chunks.",
"start": 4.76,
"end": 5.16
},
{
"word": "What",
"start": 6.06,
"end": 6.32
},
{
"word": "do",
"start": 6.32,
"end": 6.44
},
{
"word": "you",
"start": 6.44,
"end": 6.58
},
{
"word": "think?",
"start": 6.58,
"end": 6.84
}
]

View File

@@ -0,0 +1,177 @@
[
{
"word": "Ok,",
"start": 2.02,
"end": 2.38
},
{
"word": "là",
"start": 2.52,
"end": 2.58
},
{
"word": "c",
"start": 2.58,
"end": 2.74
},
{
"word": "'est",
"start": 2.74,
"end": 2.76
},
{
"word": "un",
"start": 2.76,
"end": 2.86
},
{
"word": "test,",
"start": 2.86,
"end": 3.2
},
{
"word": "on",
"start": 3.34,
"end": 3.34
},
{
"word": "veut",
"start": 3.34,
"end": 3.48
},
{
"word": "voir",
"start": 3.48,
"end": 3.86
},
{
"word": "si",
"start": 3.86,
"end": 4.14
},
{
"word": "ça",
"start": 4.14,
"end": 4.26
},
{
"word": "arrive",
"start": 4.26,
"end": 4.36
},
{
"word": "à",
"start": 4.36,
"end": 4.5
},
{
"word": "capté",
"start": 4.5,
"end": 4.78
},
{
"word": "le",
"start": 4.78,
"end": 4.9
},
{
"word": "silence.",
"start": 4.9,
"end": 5.44
},
{
"word": "Là",
"start": 9.24,
"end": 9.6
},
{
"word": "il",
"start": 9.6,
"end": 9.78
},
{
"word": "est",
"start": 9.78,
"end": 9.84
},
{
"word": "une",
"start": 9.84,
"end": 9.96
},
{
"word": "telle",
"start": 9.96,
"end": 10.12
},
{
"word": "seconde",
"start": 10.12,
"end": 10.38
},
{
"word": "de",
"start": 10.38,
"end": 10.48
},
{
"word": "silence",
"start": 10.48,
"end": 10.78
},
{
"word": "et",
"start": 10.78,
"end": 11.06
},
{
"word": "je",
"start": 11.06,
"end": 11.16
},
{
"word": "vous",
"start": 11.16,
"end": 11.32
},
{
"word": "parle.",
"start": 11.32,
"end": 11.68
},
{
"word": "Et",
"start": 13.28,
"end": 13.64
},
{
"word": "voilà,",
"start": 13.64,
"end": 13.96
},
{
"word": "allez",
"start": 14.36,
"end": 14.62
},
{
"word": "on",
"start": 14.62,
"end": 14.78
},
{
"word": "va",
"start": 14.78,
"end": 14.88
},
{
"word": "tester",
"start": 14.88,
"end": 15.06
},
{
"word": "ça.",
"start": 15.06,
"end": 15.36
}
]

View File

@@ -0,0 +1,382 @@
[
{
"word": "Transcription",
"start": 0.0,
"end": 0.6
},
{
"word": "technology",
"start": 0.6,
"end": 1.24
},
{
"word": "has",
"start": 1.24,
"end": 1.5
},
{
"word": "improved",
"start": 1.5,
"end": 1.96
},
{
"word": "so",
"start": 1.96,
"end": 2.32
},
{
"word": "much",
"start": 2.32,
"end": 2.68
},
{
"word": "in",
"start": 2.68,
"end": 2.94
},
{
"word": "the",
"start": 2.94,
"end": 3.02
},
{
"word": "past",
"start": 3.02,
"end": 3.24
},
{
"word": "few",
"start": 3.24,
"end": 3.5
},
{
"word": "years.",
"start": 3.5,
"end": 3.96
},
{
"word": "Have",
"start": 4.56,
"end": 4.74
},
{
"word": "you",
"start": 4.74,
"end": 4.9
},
{
"word": "noticed",
"start": 4.9,
"end": 5.26
},
{
"word": "how",
"start": 5.26,
"end": 5.52
},
{
"word": "accurate",
"start": 5.52,
"end": 6.08
},
{
"word": "real",
"start": 6.08,
"end": 6.42
},
{
"word": "-time",
"start": 6.42,
"end": 6.74
},
{
"word": "speech",
"start": 6.74,
"end": 7.24
},
{
"word": "to",
"start": 7.24,
"end": 7.46
},
{
"word": "text",
"start": 7.46,
"end": 7.78
},
{
"word": "is",
"start": 7.78,
"end": 8.0
},
{
"word": "now?",
"start": 8.0,
"end": 8.3
},
{
"word": "Absolutely.",
"start": 8.7,
"end": 9.16
},
{
"word": "I",
"start": 10.04,
"end": 10.38
},
{
"word": "use",
"start": 10.38,
"end": 10.56
},
{
"word": "it",
"start": 10.56,
"end": 10.76
},
{
"word": "all",
"start": 10.76,
"end": 10.9
},
{
"word": "the",
"start": 10.9,
"end": 11.04
},
{
"word": "time",
"start": 11.04,
"end": 11.32
},
{
"word": "for",
"start": 11.32,
"end": 11.54
},
{
"word": "taking",
"start": 11.54,
"end": 11.86
},
{
"word": "notes",
"start": 11.86,
"end": 12.16
},
{
"word": "during",
"start": 12.16,
"end": 12.54
},
{
"word": "meetings.",
"start": 12.54,
"end": 12.94
},
{
"word": "It's",
"start": 13.6,
"end": 13.8
},
{
"word": "amazing",
"start": 13.8,
"end": 14.1
},
{
"word": "how",
"start": 14.1,
"end": 14.48
},
{
"word": "it",
"start": 14.48,
"end": 14.62
},
{
"word": "can",
"start": 14.62,
"end": 14.74
},
{
"word": "recognise",
"start": 14.74,
"end": 15.24
},
{
"word": "different",
"start": 15.24,
"end": 15.68
},
{
"word": "speakers",
"start": 15.68,
"end": 16.16
},
{
"word": "and",
"start": 16.16,
"end": 16.8
},
{
"word": "even",
"start": 16.8,
"end": 17.1
},
{
"word": "add",
"start": 17.1,
"end": 17.44
},
{
"word": "punctuation.",
"start": 17.44,
"end": 18.36
},
{
"word": "Yeah,",
"start": 18.88,
"end": 19.16
},
{
"word": "but",
"start": 19.36,
"end": 19.52
},
{
"word": "sometimes",
"start": 19.52,
"end": 20.16
},
{
"word": "noise",
"start": 20.16,
"end": 20.54
},
{
"word": "can",
"start": 20.54,
"end": 20.8
},
{
"word": "still",
"start": 20.8,
"end": 21.1
},
{
"word": "cause",
"start": 21.1,
"end": 21.44
},
{
"word": "mistakes.",
"start": 21.44,
"end": 21.94
},
{
"word": "Does",
"start": 22.68,
"end": 22.9
},
{
"word": "this",
"start": 22.9,
"end": 23.12
},
{
"word": "system",
"start": 23.12,
"end": 23.46
},
{
"word": "handle",
"start": 23.46,
"end": 23.88
},
{
"word": "that",
"start": 23.88,
"end": 24.12
},
{
"word": "well?",
"start": 24.12,
"end": 24.42
},
{
"word": "It",
"start": 24.42,
"end": 25.32
},
{
"word": "does",
"start": 25.32,
"end": 25.48
},
{
"word": "a",
"start": 25.48,
"end": 25.62
},
{
"word": "pretty",
"start": 25.62,
"end": 25.88
},
{
"word": "good",
"start": 25.88,
"end": 26.08
},
{
"word": "job",
"start": 26.08,
"end": 26.32
},
{
"word": "filtering",
"start": 26.32,
"end": 26.8
},
{
"word": "noise,",
"start": 26.8,
"end": 27.18
},
{
"word": "especially",
"start": 27.36,
"end": 28.0
},
{
"word": "with",
"start": 28.0,
"end": 28.28
},
{
"word": "models",
"start": 28.28,
"end": 28.62
},
{
"word": "that",
"start": 28.62,
"end": 28.94
},
{
"word": "use",
"start": 28.94,
"end": 29.22
},
{
"word": "voice",
"start": 29.22,
"end": 29.54
},
{
"word": "active.",
"start": 29.54,
"end": 29.9
}
]

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""Generate word-level timestamped transcripts using faster-whisper (offline).
Produces one JSON file per audio with: [{word, start, end}, ...]
"""
import json
import os
from faster_whisper import WhisperModel
AUDIO_DIR = os.path.dirname(os.path.abspath(__file__))
FILES = [
("00_00_07_english_1_speaker.wav", "en"),
("00_00_16_french_1_speaker.wav", "fr"),
("00_00_30_english_3_speakers.wav", "en"),
]
def main():
print("Loading faster-whisper model (base, cpu, float32)...")
model = WhisperModel("base", device="cpu", compute_type="float32")
for filename, lang in FILES:
audio_path = os.path.join(AUDIO_DIR, filename)
out_path = os.path.join(
AUDIO_DIR, filename.rsplit(".", 1)[0] + ".transcript.json"
)
print(f"\n{'='*60}")
print(f"Transcribing: {filename} (language={lang})")
print(f"{'='*60}")
segments, info = model.transcribe(
audio_path, word_timestamps=True, language=lang
)
words = []
for segment in segments:
if segment.words:
for w in segment.words:
words.append({
"word": w.word.strip(),
"start": round(w.start, 3),
"end": round(w.end, 3),
})
print(f" {w.start:6.2f} - {w.end:6.2f} {w.word.strip()}")
with open(out_path, "w", encoding="utf-8") as f:
json.dump(words, f, indent=2, ensure_ascii=False)
print(f"\n -> {len(words)} words written to {os.path.basename(out_path)}")
print("\nDone.")
if __name__ == "__main__":
main()

BIN
benchmark_chart.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

BIN
benchmark_scatter.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

View File

@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730"> <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
## Running this extension ## Running this extension
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory. 1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
2. Load the `chrome-extension` directory in Chrome as an unpacked extension. 2. Load the `chrome-extension` directory in Chrome as an unpacked extension.

52
compose.yml Normal file
View File

@@ -0,0 +1,52 @@
services:
wlk-gpu-sortformer:
build:
context: .
dockerfile: Dockerfile
args:
EXTRAS: ${GPU_SORTFORMER_EXTRAS:-cu129,diarization-sortformer}
image: wlk:gpu-sortformer
gpus: all
ports:
- "8000:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
command: ["--model", "medium", "--diarization", "--pcm-input"]
wlk-gpu-voxtral:
build:
context: .
dockerfile: Dockerfile
args:
EXTRAS: ${GPU_VOXTRAL_EXTRAS:-cu129,voxtral-hf,translation}
image: wlk:gpu-voxtral
gpus: all
ports:
- "8001:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
command: ["--backend", "voxtral", "--pcm-input"]
wlk-cpu:
build:
context: .
dockerfile: Dockerfile.cpu
args:
EXTRAS: ${CPU_EXTRAS:-cpu,diarization-diart,translation}
image: wlk:cpu
ports:
- "8000:8000"
volumes:
- hf-cache:/root/.cache/huggingface/hub
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
environment:
- HF_TOKEN
volumes:
hf-cache:

View File

@@ -0,0 +1,71 @@
### Alignment between STT Tokens and Diarization Segments
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
> `#` Is the split between the `t-1` prediction and `t` prediction.
## Example 1:
```text
punctuations_segments : __#_______.__________________!____
diarization_segments:
SPK1 __#____________
SPK2 # ___________________
-->
ALIGNED SPK1 __#_______.
ALIGNED SPK2 # __________________!____
t-1 output:
SPK1: __#
SPK2: NO
DIARIZATION BUFFER: NO
t output:
SPK1: __#__.
SPK2: __________________!____
DIARIZATION BUFFER: No
```
## Example 2:
```text
punctuations_segments : _____#__.___________
diarization_segments:
SPK1 ___ #
SPK2 __#______________
-->
ALIGNED SPK1 _____#__.
ALIGNED SPK2 # ___________
t-1 output:
SPK1: ___ #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: __#__.
SPK2: ___________
DIARIZATION BUFFER: No
```
## Example 3:
```text
punctuations_segments : ___.__#__________
diarization_segments:
SPK1 ______#__
SPK2 # ________
-->
ALIGNED SPK1 ___. #
ALIGNED SPK2 __#__________
t-1 output:
SPK1: ___. #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: #
SPK2: __#___________
DIARIZATION BUFFER: NO
```

View File

@@ -1,109 +0,0 @@
# Available Whisper model sizes:
- tiny.en (english only)
- tiny
- base.en (english only)
- base
- small.en (english only)
- small
- medium.en (english only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo
## How to choose?
### Language Support
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
### Resource Constraints
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
- `base`: Good balance of speed and accuracy for basic use cases
- `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
- `large-v2`: Excellent accuracy, good multilingual support
- `large-v3`: Best overall accuracy and language support
### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Model Comparison Table
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
- `largev3turbo`: ~6GB VRAM
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
## Quick Decision Matrix
**Choose 600M**: Limited resources, close to 0 lag
**Choose 1.3B**: Quality matters
**Choose Transformers**: On Apple Silicon

View File

@@ -0,0 +1,106 @@
# Models and Model Paths
## Defaults
**Default Whisper Model**: `base`
When no model is specified, WhisperLiveKit uses the `base` model, which provides a good balance of speed and accuracy for most use cases.
**Default Model Cache Directory**: `~/.cache/whisper`
Models are automatically downloaded from OpenAI's model hub and cached in this directory. You can override this with `--model_cache_dir`.
**Default Translation Model**: `600M` (NLLB-200-distilled)
When translation is enabled, the 600M distilled NLLB model is used by default. This provides good quality with minimal resource usage.
**Default Translation Backend**: `transformers`
The translation backend defaults to Transformers. On Apple Silicon, this automatically uses MPS acceleration for better performance.
---
## Available Whisper model sizes:
| Available Model | Speed | Accuracy | Multilingual | Translation | Hardware Requirements | Best Use Case |
|--------------------|----------|-----------|--------------|-------------|----------------------|----------------------------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | ~1GB VRAM | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | ~1GB VRAM | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | ~2GB VRAM | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | ~5GB VRAM | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Good overall accuracy & language support |
| large-v3 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Best overall accuracy & language support |
| large-v3-turbo | Fast | Excellent | Yes | No | ~6GB VRAM | Fast, high-quality transcription |
### How to choose?
#### Language Support
- **English only**: Use `.en` (ex: `base.en`) models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
#### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
_______________________
# Custom Models:
The `--model-path` parameter accepts:
## File Path
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
## Hugging Face Repo ID
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignment heads are set to be all the heads of the last half layer of decoder.
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2

View File

@@ -1,14 +0,0 @@
# Model Path Formats
The `--model-path` parameter accepts:
## File Path
- **`.pt` format only** (required for AlignAtt policy decoder)
## Directory Path (recommended)
Must contain:
- **`.pt` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)

View File

@@ -1,6 +1,114 @@
# Supported Languages # Transcription: Supported Language
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system. WLK supports transcription in the following languages:
| ISO Code | Language Name |
|----------|---------------------|
| en | English |
| zh | Chinese |
| de | German |
| es | Spanish |
| ru | Russian |
| ko | Korean |
| fr | French |
| ja | Japanese |
| pt | Portuguese |
| tr | Turkish |
| pl | Polish |
| ca | Catalan |
| nl | Dutch |
| ar | Arabic |
| sv | Swedish |
| it | Italian |
| id | Indonesian |
| hi | Hindi |
| fi | Finnish |
| vi | Vietnamese |
| he | Hebrew |
| uk | Ukrainian |
| el | Greek |
| ms | Malay |
| cs | Czech |
| ro | Romanian |
| da | Danish |
| hu | Hungarian |
| ta | Tamil |
| no | Norwegian |
| th | Thai |
| ur | Urdu |
| hr | Croatian |
| bg | Bulgarian |
| lt | Lithuanian |
| la | Latin |
| mi | Maori |
| ml | Malayalam |
| cy | Welsh |
| sk | Slovak |
| te | Telugu |
| fa | Persian |
| lv | Latvian |
| bn | Bengali |
| sr | Serbian |
| az | Azerbaijani |
| sl | Slovenian |
| kn | Kannada |
| et | Estonian |
| mk | Macedonian |
| br | Breton |
| eu | Basque |
| is | Icelandic |
| hy | Armenian |
| ne | Nepali |
| mn | Mongolian |
| bs | Bosnian |
| kk | Kazakh |
| sq | Albanian |
| sw | Swahili |
| gl | Galician |
| mr | Marathi |
| pa | Punjabi |
| si | Sinhala |
| km | Khmer |
| sn | Shona |
| yo | Yoruba |
| so | Somali |
| af | Afrikaans |
| oc | Occitan |
| ka | Georgian |
| be | Belarusian |
| tg | Tajik |
| sd | Sindhi |
| gu | Gujarati |
| am | Amharic |
| yi | Yiddish |
| lo | Lao |
| uz | Uzbek |
| fo | Faroese |
| ht | Haitian Creole |
| ps | Pashto |
| tk | Turkmen |
| nn | Nynorsk |
| mt | Maltese |
| sa | Sanskrit |
| lb | Luxembourgish |
| my | Myanmar |
| bo | Tibetan |
| tl | Tagalog |
| mg | Malagasy |
| as | Assamese |
| tt | Tatar |
| haw | Hawaiian |
| ln | Lingala |
| ha | Hausa |
| ba | Bashkir |
| jw | Javanese |
| su | Sundanese |
| yue | Cantonese |
# Translation: Supported Languages
WLK supports translation into **201 languages** from the FLORES-200 dataset through the [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) translation system.
## How to Specify Languages ## How to Specify Languages

View File

@@ -0,0 +1,43 @@
# Technical Integration Guide
This document introduce how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
---
## 1. Runtime Components
| Layer | File(s) | Purpose |
|-------|---------|---------|
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
---
## 2. Running Without the Bundled Frontend
1. Start the server/engine however you like:
```bash
wlk --model small --language en --host 0.0.0.0 --port 9000
# or launch your own app that instantiates TranscriptionEngine(...)
```
2. Build your own client (browser, mobile, desktop) that:
- Opens `ws(s)://<host>:<port>/asr`
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
- Consumes the JSON payload defined in `docs/API.md`.
---
## 3. Running Without FastAPI
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently. Just ensure `ffmpeg` is available.

140
docs/troubleshooting.md Normal file
View File

@@ -0,0 +1,140 @@
# Troubleshooting
## GPU drivers & cuDNN visibility
### Linux error: `Unable to load libcudnn_ops.so* / cudnnCreateTensorDescriptor`
> Reported in issue #271 (Arch/CachyOS)
`faster-whisper` (used for the SimulStreaming encoder) dynamically loads cuDNN.
If the runtime cannot find `libcudnn_*`, verify that CUDA and cuDNN match the PyTorch build you installed:
1. **Install CUDA + cuDNN** (Arch/CachyOS example):
```bash
sudo pacman -S cuda cudnn
sudo ldconfig
```
2. **Make sure the shared objects are visible**:
```bash
ls /usr/lib/libcudnn*
```
3. **Check what CUDA version PyTorch expects** and match that with the driver you installed:
```bash
python - <<'EOF'
import torch
print(torch.version.cuda)
EOF
nvcc --version
```
4. If you installed CUDA in a non-default location, export `CUDA_HOME` and add `$CUDA_HOME/lib64` to `LD_LIBRARY_PATH`.
Once the CUDA/cuDNN versions match, `whisperlivekit-server` starts normally.
### Windows error: `Could not locate cudnn_ops64_9.dll`
> Reported in issue #286 (Conda on Windows)
PyTorch bundles cuDNN DLLs inside your environment (`<env>\Lib\site-packages\torch\lib`).
When `ctranslate2` or `faster-whisper` cannot find `cudnn_ops64_9.dll`:
1. Locate the DLL shipped with PyTorch, e.g.
```
E:\conda\envs\WhisperLiveKit\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
```
2. Add that directory to your `PATH` **or** copy the `cudnn_*64_9.dll` files into a directory that is already on `PATH` (such as the environment's `Scripts/` folder).
3. Restart the shell before launching `wlk`.
Installing NVIDIA's standalone cuDNN 9.x and pointing `PATH`/`CUDNN_PATH` to it works as well, but is usually not required.
---
## PyTorch / CTranslate2 GPU builds
### `Torch not compiled with CUDA enabled`
> Reported in issue #284
If `torch.zeros(1).cuda()` raises that assertion it means you installed a CPU-only wheel.
Install the GPU-enabled wheels that match your CUDA toolkit:
```bash
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
```
Replace `cu130` with the CUDA version supported by your driver (see [PyTorch install selector](https://pytorch.org/get-started/locally/)).
Validate with:
```python
import torch
print(torch.cuda.is_available(), torch.cuda.get_device_name())
```
### `CTranslate2 device count: 0` or `Could not infer dtype of ctranslate2._ext.StorageView`
> Follow-up in issue #284
`ctranslate2` publishes separate CPU and CUDA wheels. The default `pip install ctranslate2` brings the CPU build, which makes WhisperLiveKit fall back to CPU tensors and leads to the dtype error above.
1. Uninstall the CPU build: `pip uninstall -y ctranslate2`.
2. Install the CUDA wheel that matches your toolkit (example for CUDA 13.0):
```bash
pip install ctranslate2==4.5.0 -f https://opennmt.net/ctranslate2/whl/cu130
```
(See the [CTranslate2 installation table](https://opennmt.net/CTranslate2/installation.html) for other CUDA versions.)
3. Verify:
```python
import ctranslate2
print("CUDA devices:", ctranslate2.get_cuda_device_count())
print("CUDA compute types:", ctranslate2.get_supported_compute_types("cuda", 0))
```
**Note for aarch64 systems (e.g., NVIDIA DGX Spark):** Pre-built CUDA wheels may not be available for all CUDA versions on ARM architectures. If the wheel installation fails, you may need to compile CTranslate2 from source with CUDA support enabled.
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
---
## Hopper / Blackwell (`sm_121a`) systems
> Reported in issues #276 and #284 (NVIDIA DGX Spark)
CUDA 12.1a GPUs (e.g., NVIDIA GB10 on DGX Spark) ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual configuration.
### Error: `ptxas fatal : Value 'sm_121a' is not defined for option 'gpu-name'`
If you encounter this error after compiling CTranslate2 from source on aarch64 systems, Triton's bundled `ptxas` may not support the `sm_121a` architecture. The solution is to replace Triton's `ptxas` with the system's CUDA `ptxas`:
```bash
# Find your Python environment's Triton directory
python -c "import triton; import os; print(os.path.dirname(triton.__file__))"
# Copy the system ptxas to Triton's backend directory
# Replace <triton_path> with the output above
cp /usr/local/cuda/bin/ptxas <triton_path>/backends/nvidia/bin/ptxas
```
For example, in a virtual environment:
```bash
cp /usr/local/cuda/bin/ptxas ~/wlk/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
```
**Note:** On DGX Spark systems, CUDA is typically already in `PATH` (`/usr/local/cuda/bin`), so explicit `CUDA_HOME` and `PATH` exports may not be necessary. Verify with `which ptxas` before copying.
### Alternative: Environment variable approach
If the above doesn't work, you can try setting environment variables (though this may not resolve the `sm_121a` issue on all systems):
```bash
export CUDA_HOME="/usr/local/cuda-13.0"
export PATH="$CUDA_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
# Tell Triton where the new ptxas lives
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
# Force PyTorch to JIT kernels for all needed architectures
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
```
After applying the fix, restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
---
Need help with another recurring issue? Open a GitHub discussion or PR and reference this document so we can keep it current.

View File

@@ -4,66 +4,138 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "whisperlivekit" name = "whisperlivekit"
version = "0.2.13" version = "0.2.19"
description = "Real-time speech-to-text with speaker diarization using Whisper" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [{ name = "Quentin Fuxa" }]
{ name = "Quentin Fuxa" }
]
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.11, <3.14"
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech" "Topic :: Multimedia :: Sound/Audio :: Speech",
] ]
dependencies = [ dependencies = [
"fastapi", "fastapi",
"librosa", "librosa",
"soundfile", "soundfile",
"faster-whisper",
"uvicorn", "uvicorn",
"websockets", "websockets",
"torchaudio>=2.0.0", "huggingface-hub>=0.25.0",
"faster-whisper>=1.2.0",
"torch>=2.0.0", "torch>=2.0.0",
"torchaudio>=2.0.0",
"tqdm", "tqdm",
"tiktoken", "tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
] ]
[project.optional-dependencies] [project.optional-dependencies]
test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
translation = ["nllw"] translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"] sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
mlx-whisper = [
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
]
voxtral-mlx = [
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
"mistral-common[audio]",
]
voxtral-hf = [
"transformers>=5.2.0; python_version >= '3.10'",
"mistral-common[audio]",
"accelerate>=0.12",
]
cpu = ["torch>=2.0.0", "torchaudio>=2.0.0"]
cu129 = [
"torch>=2.0.0",
"torchaudio>=2.0.0",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")',
]
diarization-sortformer = [
"nemo-toolkit[asr]>2.4; python_version >= '3.10' and python_version < '3.13'",
]
diarization-diart = [
"diart",
"torch<2.9.0",
"torchaudio<2.9.0",
"torchvision<0.24.0",
]
[dependency-groups]
dev = ["rich>=14.3.3"]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "cu129" },
],
[
{ extra = "diarization-diart" },
{ extra = "cu129" },
],
[
{ extra = "voxtral-hf" },
{ extra = "diarization-sortformer" },
],
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
]
torchaudio = [
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu129"
url = "https://download.pytorch.org/whl/cu129"
explicit = true
[project.urls] [project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit" Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
[project.scripts] [project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main" whisperlivekit-server = "whisperlivekit.basic_server:main"
wlk = "whisperlivekit.basic_server:main"
[tool.setuptools] [tool.setuptools]
packages = [ packages = [
"whisperlivekit", "whisperlivekit",
"whisperlivekit.diarization", "whisperlivekit.diarization",
"whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper",
"whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.mlx",
"whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.whisper",
"whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.whisper.assets",
"whisperlivekit.whisper.normalizers",
"whisperlivekit.web", "whisperlivekit.web",
"whisperlivekit.whisper_streaming_custom", "whisperlivekit.local_agreement",
"whisperlivekit.vad_models" "whisperlivekit.voxtral_mlx",
"whisperlivekit.silero_vad_models",
] ]
[tool.setuptools.package-data] [tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"] whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"] "whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"] "whisperlivekit.whisper.normalizers" = ["*.json"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

291
run_benchmark.py Normal file
View File

@@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
Comprehensive benchmark runner for WhisperLiveKit.
Tests all available backend+policy combinations across multiple audio files,
model sizes, and VAC on/off configurations. Outputs structured JSON that
is consumed by the report generator.
Usage:
python run_benchmark.py # full benchmark
python run_benchmark.py --quick # subset (tiny models, fewer combos)
python run_benchmark.py --json results.json # custom output path
"""
import argparse
import asyncio
import gc
import json
import logging
import platform
import subprocess
import sys
import time
from dataclasses import asdict
from pathlib import Path
logging.basicConfig(level=logging.WARNING, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger("benchmark")
logger.setLevel(logging.INFO)
# Re-use harness functions
sys.path.insert(0, str(Path(__file__).parent))
from test_backend_offline import (
AUDIO_TESTS_DIR,
SAMPLE_RATE,
TestResult,
create_engine,
discover_audio_files,
download_sample_audio,
load_audio,
run_test,
)
CACHE_DIR = Path(__file__).parent / ".test_cache"
def get_system_info() -> dict:
"""Collect system metadata for the report."""
info = {
"platform": platform.platform(),
"machine": platform.machine(),
"processor": platform.processor(),
"python_version": platform.python_version(),
}
# macOS: get chip info
try:
chip = subprocess.check_output(
["sysctl", "-n", "machdep.cpu.brand_string"], text=True
).strip()
info["cpu"] = chip
except Exception:
info["cpu"] = platform.processor()
# RAM
try:
mem_bytes = int(
subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip()
)
info["ram_gb"] = round(mem_bytes / (1024**3))
except Exception:
info["ram_gb"] = None
# Backend versions
versions = {}
try:
import faster_whisper
versions["faster-whisper"] = faster_whisper.__version__
except ImportError:
pass
try:
import mlx_whisper # noqa: F401
versions["mlx-whisper"] = "installed"
except ImportError:
pass
try:
import mlx.core as mx
versions["mlx"] = mx.__version__
except ImportError:
pass
try:
import transformers
versions["transformers"] = transformers.__version__
except ImportError:
pass
try:
import torch
versions["torch"] = torch.__version__
except ImportError:
pass
info["backend_versions"] = versions
return info
def detect_combos(quick: bool = False) -> list:
"""Build list of (backend, policy, model_size) combos to test."""
combos = []
# Model sizes to test
model_sizes = ["tiny", "base", "small"] if not quick else ["tiny", "base"]
# faster-whisper
try:
import faster_whisper # noqa: F401
for model in model_sizes:
combos.append({"backend": "faster-whisper", "policy": "localagreement", "model": model})
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "model": model})
except ImportError:
pass
# mlx-whisper
try:
import mlx_whisper # noqa: F401
for model in model_sizes:
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "model": model})
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "model": model})
except ImportError:
pass
# voxtral-mlx (single model, single policy)
try:
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "model": ""})
except ImportError:
pass
# voxtral HF (single model, single policy)
try:
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
combos.append({"backend": "voxtral", "policy": "voxtral", "model": ""})
except ImportError:
pass
return combos
def collect_audio_files() -> list:
"""Collect all benchmark audio files."""
files = []
# audio_tests/ directory
if AUDIO_TESTS_DIR.is_dir():
files.extend(discover_audio_files(str(AUDIO_TESTS_DIR)))
# JFK sample
jfk = CACHE_DIR / "jfk.wav"
if not jfk.exists():
jfk = download_sample_audio()
if jfk.exists():
files.append(jfk)
return files
async def run_single_combo(
combo: dict, audio_files: list, vac: bool, lan: str, max_duration: float,
) -> list:
"""Run one backend+policy+model combo across all audio files."""
backend = combo["backend"]
policy = combo["policy"]
model = combo["model"]
results = []
try:
engine = create_engine(
backend=backend,
model_size=model,
lan=lan,
vac=vac,
policy=policy,
)
# Quiet noisy loggers
for mod in (
"whisperlivekit.audio_processor",
"whisperlivekit.simul_whisper",
"whisperlivekit.tokens_alignment",
"whisperlivekit.simul_whisper.align_att_base",
"whisperlivekit.simul_whisper.simul_whisper",
):
logging.getLogger(mod).setLevel(logging.WARNING)
for audio_path in audio_files:
duration = len(load_audio(str(audio_path))) / SAMPLE_RATE
if duration > max_duration:
logger.info(f" Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s)")
continue
file_lan = lan
if "french" in audio_path.name.lower() and lan == "en":
file_lan = "fr"
audio = load_audio(str(audio_path))
result = await run_test(
engine, audio, chunk_ms=100, realtime=False,
audio_file=audio_path.name, backend=backend,
policy=policy, lan=file_lan,
)
# Tag with extra metadata
result_dict = asdict(result)
result_dict["model_size"] = model
result_dict["vac"] = vac
results.append(result_dict)
except Exception as e:
logger.error(f" FAILED: {e}")
import traceback
traceback.print_exc()
return results
async def run_full_benchmark(combos, audio_files, max_duration=60.0):
"""Run all combos with VAC on and off."""
all_results = []
total = len(combos) * 2 # x2 for VAC on/off
idx = 0
for combo in combos:
for vac in [True, False]:
idx += 1
vac_str = "VAC=on" if vac else "VAC=off"
desc = f"{combo['backend']} / {combo['policy']}"
if combo["model"]:
desc += f" / {combo['model']}"
desc += f" / {vac_str}"
print(f"\n{'='*70}")
print(f"[{idx}/{total}] {desc}")
print(f"{'='*70}")
results = await run_single_combo(
combo, audio_files, vac=vac, lan="en", max_duration=max_duration,
)
all_results.extend(results)
# Free memory between combos
gc.collect()
return all_results
def main():
parser = argparse.ArgumentParser(description="Run comprehensive WhisperLiveKit benchmark")
parser.add_argument("--quick", action="store_true", help="Quick mode: fewer models and combos")
parser.add_argument("--json", default="benchmark_results.json", dest="json_output", help="Output JSON path")
parser.add_argument("--max-duration", type=float, default=60.0, help="Max audio duration in seconds")
args = parser.parse_args()
system_info = get_system_info()
combos = detect_combos(quick=args.quick)
audio_files = collect_audio_files()
print(f"System: {system_info.get('cpu', 'unknown')}, {system_info.get('ram_gb', '?')}GB RAM")
print(f"Backends: {list(system_info['backend_versions'].keys())}")
print(f"Combos to test: {len(combos)} x 2 (VAC on/off) = {len(combos)*2}")
print(f"Audio files: {[f.name for f in audio_files]}")
print()
t0 = time.time()
all_results = asyncio.run(
run_full_benchmark(combos, audio_files, max_duration=args.max_duration)
)
total_time = time.time() - t0
output = {
"system_info": system_info,
"benchmark_date": time.strftime("%Y-%m-%d %H:%M"),
"total_benchmark_time_s": round(total_time, 1),
"n_combos": len(combos) * 2,
"n_audio_files": len(audio_files),
"results": all_results,
}
Path(args.json_output).write_text(json.dumps(output, indent=2, ensure_ascii=False))
print(f"\nBenchmark complete in {total_time:.0f}s. Results: {args.json_output}")
if __name__ == "__main__":
main()

BIN
scripts/alignment_heads.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

View File

@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
Optionally shrink the supported audio chunk length (in seconds) by trimming the
encoder positional embeddings and updating the stored model dimensions.
"""
import argparse
import json
import os
from pathlib import Path
from typing import Dict, Tuple
import torch
from whisperlivekit.whisper import _convert_hf_state_dict
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
from whisperlivekit.whisper.model import ModelDimensions
from whisperlivekit.whisper.utils import exact_div
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
safetensor_path = repo_path / "model.safetensors"
bin_path = repo_path / "pytorch_model.bin"
if safetensor_path.is_file():
try:
from safetensors.torch import load_file # type: ignore
except Exception as exc: # pragma: no cover - import guard
raise RuntimeError(
"Install safetensors to load model.safetensors "
"(pip install safetensors)"
) from exc
return load_file(str(safetensor_path))
if bin_path.is_file():
return torch.load(bin_path, map_location="cpu")
raise FileNotFoundError(
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
)
def _load_config(repo_path: Path) -> Dict:
config_path = repo_path / "config.json"
if not config_path.is_file():
raise FileNotFoundError(
f"Hugging Face checkpoint at {repo_path} is missing config.json"
)
with open(config_path, "r", encoding="utf-8") as fp:
return json.load(fp)
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
n_samples = int(round(chunk_length * SAMPLE_RATE))
expected_samples = chunk_length * SAMPLE_RATE
if abs(n_samples - expected_samples) > 1e-6:
raise ValueError(
"chunk_length must align with sample rate so that "
"chunk_length * SAMPLE_RATE is an integer"
)
n_frames = exact_div(n_samples, HOP_LENGTH)
n_audio_ctx = exact_div(n_frames, 2)
return n_frames, n_audio_ctx
def _build_dims(config: Dict, chunk_length: float) -> Dict:
base_dims = ModelDimensions(
n_mels=config["num_mel_bins"],
n_audio_ctx=config["max_source_positions"],
n_audio_state=config["d_model"],
n_audio_head=config["encoder_attention_heads"],
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
n_vocab=config["vocab_size"],
n_text_ctx=config["max_target_positions"],
n_text_state=config["d_model"],
n_text_head=config["decoder_attention_heads"],
n_text_layer=config["decoder_layers"],
).__dict__.copy()
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
base_dims["n_audio_ctx"] = n_audio_ctx
base_dims["chunk_length"] = chunk_length
return base_dims
def _trim_positional_embedding(
state_dict: Dict[str, torch.Tensor], target_ctx: int
) -> None:
key = "encoder.positional_embedding"
if key not in state_dict:
raise KeyError(f"{key} missing from converted state dict")
tensor = state_dict[key]
if tensor.shape[0] < target_ctx:
raise ValueError(
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
)
if tensor.shape[0] == target_ctx:
return
state_dict[key] = tensor[:target_ctx].contiguous()
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
state_dict = _load_state_dict(hf_path)
converted = _convert_hf_state_dict(state_dict)
config = _load_config(hf_path)
dims = _build_dims(config, chunk_length)
_trim_positional_embedding(converted, dims["n_audio_ctx"])
package = {"dims": dims, "model_state_dict": converted}
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(package, output_path)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
)
parser.add_argument(
"hf_path",
type=str,
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
)
parser.add_argument(
"--output",
type=str,
default="converted-whisper.pt",
help="Destination path for the .pt file",
)
parser.add_argument(
"--chunk-length",
type=float,
default=30.0,
help="Audio chunk length in seconds to support (default: 30)",
)
return parser.parse_args()
def main():
args = parse_args()
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
output_path = Path(os.path.expanduser(args.output)).resolve()
convert_checkpoint(hf_path, output_path, args.chunk_length)
print(f"Saved converted checkpoint to {output_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,294 @@
"""Determine alignment heads for a variants, such as distilled model"""
from __future__ import annotations
import argparse
import base64
import gzip
import io
import math
import pathlib
import sys
from typing import List, Optional, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from datasets import Audio as DatasetAudio
from datasets import load_dataset
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
WHISPER_ROOT = REPO_ROOT / "whisper"
sys.path.insert(0, str(REPO_ROOT))
sys.path.insert(0, str(WHISPER_ROOT))
from whisper import load_model
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
def load_dataset_clips(name, config, split, limit):
ds = load_dataset(name, config, split=split)
ds = ds.cast_column("audio", DatasetAudio(decode=False))
clips = []
for idx, row in enumerate(ds):
if limit is not None and idx >= limit:
break
audio_field = row["audio"]
transcript = row["text"]
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
if waveform_np.ndim > 1:
waveform_np = waveform_np.mean(axis=1)
waveform = waveform_np
transcript = str(transcript)
clips.append((waveform, transcript))
return clips
def load_clips(args):
return load_dataset_clips(
args.dataset,
args.dataset_config,
args.dataset_split,
args.dataset_num_samples,
)
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
return waveform
def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="pytorch_model.bin",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Torch device to run on",
)
parser.add_argument(
"--dataset",
type=str,
default="librispeech_asr"
)
parser.add_argument(
"--dataset-config",
type=str,
default="clean"
)
parser.add_argument(
"--dataset-split",
type=str,
default="validation[:1%]",
)
parser.add_argument(
"--dataset-num-samples",
type=int,
default=16,
)
parser.add_argument(
"--threshold",
type=float,
default=1.5,
help="Z score threshold for a head to be selected",
)
parser.add_argument(
"--votes",
type=float,
default=0.75,
help="percentage of clips that must vote for a head",
)
parser.add_argument(
"--output",
type=str,
default="alignment_heads.b85",
)
parser.add_argument(
"--visualize-top-k",
type=int,
default=32,
)
return parser.parse_args()
def collect_heads(
model,
tokenizer,
clips: Sequence[Tuple[AudioInput, str]],
threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = model.device
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
strengths = torch.zeros_like(votes)
for audio_source, transcript in clips:
waveform = pad_or_trim(_waveform_from_source(audio_source))
mel = log_mel_spectrogram(waveform, device=device)
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=device,
)
qks = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
for layer_idx, tensor in enumerate(qks):
if tensor is None:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1)
peak = tensor.max(dim=-1).values # [heads, tokens]
strengths[layer_idx] += peak.mean(dim=-1)
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
)
mask = (zscore > 3).any(dim=-1)
votes[layer_idx] += mask.float()
votes /= len(clips)
strengths /= len(clips)
return votes, strengths
def _select_heads_for_visualization(selection, strengths, top_k):
selected = torch.nonzero(selection, as_tuple=False)
if selected.numel() == 0:
return []
entries = [
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
for layer, head in selected
]
entries.sort(key=lambda item: item[2], reverse=True)
return entries[:top_k]
def _extract_heatmaps(
model,
tokenizer,
clip: Tuple[AudioInput, str],
heads: Sequence[Tuple[int, int, float]],
) -> dict:
if not heads:
return {}
target_map = {}
for layer, head, _ in heads:
target_map.setdefault(layer, set()).add(head)
waveform = pad_or_trim(_waveform_from_source(clip[0]))
mel = log_mel_spectrogram(waveform, device=model.device)
transcript = clip[1]
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=model.device,
)
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
heatmaps = {}
for layer_idx, tensor in enumerate(QKs):
if tensor is None or layer_idx not in target_map:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1).cpu()
for head_idx in target_map[layer_idx]:
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
return heatmaps
def _plot_heatmaps(
heads, heatmaps, output_path):
cols = min(3, len(heads))
rows = math.ceil(len(heads) / cols)
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
for idx, (layer, head, score) in enumerate(heads):
ax = axes[idx // cols][idx % cols]
mat = heatmaps.get((layer, head))
if mat is None:
ax.axis("off")
continue
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
ax.set_xlabel("time")
ax.set_ylabel("tokens")
for j in range(len(heads), rows * cols):
axes[j // cols][j % cols].axis("off")
fig.tight_layout()
fig.savefig(output_path, dpi=200)
plt.close(fig)
def _dump_mask(mask: torch.Tensor, output_path: str):
payload = mask.numpy().astype(np.bool_)
blob = base64.b85encode(gzip.compress(payload.tobytes()))
with open(output_path, "wb") as f:
f.write(blob)
def main():
args = _parse_args()
model = load_model(args.model, device=args.device)
model.eval()
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
clips = load_clips(args)
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
# selection = votes > 0.5
selection = strengths > 0.05
_dump_mask(selection.cpu(), args.output)
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,580 @@
#!/usr/bin/env python3
"""Offline Python support matrix runner for WhisperLiveKit."""
from __future__ import annotations
import argparse
import os
import shlex
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
try:
from rich.console import Console
from rich.table import Table
HAS_RICH = True
except Exception:
HAS_RICH = False
SAMPLE_URL = (
"https://github.com/pyannote/pyannote-audio/raw/develop/tutorials/assets/sample.wav"
)
SAMPLE_PATH = Path("audio_tests/support-matrix-sample.wav")
DEFAULT_LOGS_DIR = Path("outputs/python-matrix/logs")
PYTHON_VERSIONS = ("3.11", "3.12", "3.13")
CONSOLE = Console() if HAS_RICH else None
@dataclass(frozen=True)
class MatrixRow:
row_id: str
extras: tuple[str, ...]
backend: str
policy: str
diarization_backend: str
requires_gpu: bool = False
CASES = (
MatrixRow(
row_id="fw-diart-cpu",
extras=("test", "cpu", "diarization-diart"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="diart",
),
MatrixRow(
row_id="fw-sortformer-cpu",
extras=("test", "cpu", "diarization-sortformer"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="sortformer",
),
MatrixRow(
row_id="fw-sortformer-gpu",
extras=("test", "cu129", "diarization-sortformer"),
backend="faster-whisper",
policy="simulstreaming",
diarization_backend="sortformer",
requires_gpu=True,
),
MatrixRow(
row_id="voxtral-diart-cpu",
extras=("test", "cpu", "voxtral-hf", "diarization-diart"),
backend="voxtral",
policy="voxtral",
diarization_backend="diart",
),
)
EXPECTED_FAILURE_CASES = {
("3.11", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
("3.12", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
}
UNSUPPORTED_CASES = {
("3.13", "fw-sortformer-cpu"): "unsupported_py313_sortformer_protobuf",
("3.13", "fw-sortformer-gpu"): "unsupported_py313_sortformer_protobuf",
}
@dataclass(frozen=True)
class CaseResult:
python_version: str
row_id: str
status: Literal["PASS", "FAIL", "N/A"]
reason: str
duration_sec: float
hint: str = ""
log_path: str = ""
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Minimal WhisperLiveKit offline support matrix"
)
parser.add_argument(
"--timeout-sec",
type=int,
default=300,
help="Per-case timeout in seconds (default: 300)",
)
parser.add_argument(
"--logs-dir",
default=str(DEFAULT_LOGS_DIR),
help="Directory where per-case logs are written (default: outputs/python-matrix/logs)",
)
return parser.parse_args()
def safe_slug(text: str) -> str:
return text.replace("=", "-").replace("|", "__").replace("/", "-").replace(" ", "-")
def status_style(status: str) -> str:
if status == "PASS":
return "green"
if status == "FAIL":
return "bold red"
if status == "N/A":
return "yellow"
return "white"
def print_line(message: str, style: str | None = None) -> None:
if CONSOLE is None:
print(message)
return
if style:
CONSOLE.print(message, style=style, highlight=False)
else:
CONSOLE.print(message, highlight=False)
def tail_text(text: str | None, max_chars: int = 220) -> str:
if not text:
return ""
normalized = " ".join(text.split())
if len(normalized) <= max_chars:
return normalized
return normalized[-max_chars:]
def run_command(
cmd: list[str],
cwd: Path,
env: dict[str, str],
timeout: int | None = None,
log_path: Path | None = None,
log_section: str | None = None,
) -> subprocess.CompletedProcess[str]:
def _append_log(
*,
command: list[str],
section: str,
returncode: int | None,
stdout: str | None,
stderr: str | None,
timed_out: bool = False,
) -> None:
if log_path is None:
return
log_path.parent.mkdir(parents=True, exist_ok=True)
with log_path.open("a", encoding="utf-8") as f:
f.write(f"\n=== {section} ===\n")
f.write(f"$ {shlex.join(command)}\n")
if timed_out:
f.write("status: timeout\n")
else:
f.write(f"status: exit_code={returncode}\n")
if stdout:
f.write("--- stdout ---\n")
f.write(stdout)
if not stdout.endswith("\n"):
f.write("\n")
if stderr:
f.write("--- stderr ---\n")
f.write(stderr)
if not stderr.endswith("\n"):
f.write("\n")
section = log_section or "command"
try:
proc = subprocess.run(
cmd,
cwd=str(cwd),
env=env,
text=True,
capture_output=True,
check=False,
timeout=timeout,
)
except subprocess.TimeoutExpired as exc:
_append_log(
command=cmd,
section=section,
returncode=None,
stdout=exc.stdout if isinstance(exc.stdout, str) else None,
stderr=exc.stderr if isinstance(exc.stderr, str) else None,
timed_out=True,
)
raise
_append_log(
command=cmd,
section=section,
returncode=proc.returncode,
stdout=proc.stdout,
stderr=proc.stderr,
)
return proc
def detect_gpu_available() -> bool:
try:
proc = subprocess.run(
["nvidia-smi", "-L"],
text=True,
capture_output=True,
check=False,
timeout=10,
)
except (FileNotFoundError, subprocess.TimeoutExpired):
return False
return proc.returncode == 0
def download_sample(repo_root: Path) -> Path:
target = repo_root / SAMPLE_PATH
target.parent.mkdir(parents=True, exist_ok=True)
cmd = [
"curl",
"--fail",
"--location",
"--silent",
"--show-error",
SAMPLE_URL,
"--output",
str(target),
]
proc = run_command(cmd, cwd=repo_root, env=os.environ.copy())
if proc.returncode != 0:
hint = tail_text(proc.stderr or proc.stdout)
raise RuntimeError(f"sample_download_failed: {hint}")
return target
def sync_case_environment(
repo_root: Path,
python_version: str,
row: MatrixRow,
env_dir: Path,
log_path: Path,
) -> tuple[bool, str]:
cmd = ["uv", "sync", "--python", python_version, "--no-dev"]
for extra in row.extras:
cmd.extend(["--extra", extra])
env = os.environ.copy()
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
proc = run_command(
cmd,
cwd=repo_root,
env=env,
log_path=log_path,
log_section="sync",
)
if proc.returncode != 0:
return False, tail_text(proc.stderr or proc.stdout)
return True, ""
def apply_expected_failure_policy(result: CaseResult) -> CaseResult:
expected_reason = EXPECTED_FAILURE_CASES.get((result.python_version, result.row_id))
if result.status != "FAIL" or not expected_reason:
return result
override_hint = result.hint
if result.reason:
override_hint = (
f"expected_failure_override original_reason={result.reason}; {override_hint}"
if override_hint
else f"expected_failure_override original_reason={result.reason}"
)
return CaseResult(
python_version=result.python_version,
row_id=result.row_id,
status="N/A",
reason=expected_reason,
duration_sec=result.duration_sec,
hint=override_hint,
log_path=result.log_path,
)
def build_offline_command(
python_version: str,
row: MatrixRow,
sample_audio: Path,
timeout_sec: int,
) -> tuple[list[str], int | None]:
base_cmd = [
"uv",
"run",
"--python",
python_version,
"--no-sync",
"python",
"test_backend_offline.py",
"--backend",
row.backend,
"--policy",
row.policy,
"--audio",
str(sample_audio),
"--model",
"tiny",
"--diarization",
"--diarization-backend",
row.diarization_backend,
"--lan",
"en",
"--no-realtime",
]
if shutil.which("timeout"):
return ["timeout", str(timeout_sec), *base_cmd], None
return base_cmd, timeout_sec
def run_case(
repo_root: Path,
python_version: str,
row: MatrixRow,
sample_audio: Path,
timeout_sec: int,
gpu_available: bool,
logs_dir: Path,
) -> CaseResult:
start = time.monotonic()
case_slug = safe_slug(f"py{python_version}-{row.row_id}")
log_path = logs_dir / f"run-{case_slug}.log"
log_path.parent.mkdir(parents=True, exist_ok=True)
log_path.write_text("", encoding="utf-8")
unsupported_reason = UNSUPPORTED_CASES.get((python_version, row.row_id))
if unsupported_reason:
log_path.write_text(
f"[matrix] precheck_short_circuit status=N/A reason={unsupported_reason}\n",
encoding="utf-8",
)
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="N/A",
reason=unsupported_reason,
duration_sec=0.0,
hint="unsupported_case_precheck",
log_path=str(log_path),
)
if row.requires_gpu and not gpu_available:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="N/A",
reason="gpu_unavailable",
duration_sec=0.0,
hint="nvidia-smi unavailable or failed",
log_path=str(log_path),
)
env_dir = repo_root / ".matrix-envs" / safe_slug(f"py{python_version}-{row.row_id}")
sync_ok, sync_hint = sync_case_environment(
repo_root,
python_version,
row,
env_dir,
log_path=log_path,
)
if not sync_ok:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason="dependency_sync_failed",
duration_sec=round(time.monotonic() - start, 3),
hint=sync_hint,
log_path=str(log_path),
)
cmd, process_timeout = build_offline_command(
python_version, row, sample_audio, timeout_sec
)
env = os.environ.copy()
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
if row.requires_gpu:
env.pop("CUDA_VISIBLE_DEVICES", None)
else:
env["CUDA_VISIBLE_DEVICES"] = ""
try:
proc = run_command(
cmd,
cwd=repo_root,
env=env,
timeout=process_timeout,
log_path=log_path,
log_section="offline",
)
except subprocess.TimeoutExpired as exc:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason="offline_timeout",
duration_sec=round(time.monotonic() - start, 3),
hint=tail_text((exc.stderr or "") if isinstance(exc.stderr, str) else ""),
log_path=str(log_path),
)
hint = tail_text(proc.stderr or proc.stdout)
if proc.returncode == 0:
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="PASS",
reason="ok",
duration_sec=round(time.monotonic() - start, 3),
hint=hint,
log_path=str(log_path),
)
reason = "offline_timeout" if proc.returncode == 124 else "offline_run_failed"
return CaseResult(
python_version=python_version,
row_id=row.row_id,
status="FAIL",
reason=reason,
duration_sec=round(time.monotonic() - start, 3),
hint=hint,
log_path=str(log_path),
)
def print_summary(results: list[CaseResult]) -> None:
pass_count = sum(1 for row in results if row.status == "PASS")
fail_count = sum(1 for row in results if row.status == "FAIL")
na_count = sum(1 for row in results if row.status == "N/A")
if CONSOLE is None:
print("\n[matrix] results")
print("python | row | status | reason | duration_s")
print("---|---|---|---|---")
for result in results:
print(
f"{result.python_version} | {result.row_id} | {result.status} | "
f"{result.reason} | {result.duration_sec:.3f}"
)
print(
f"\n[matrix] summary pass={pass_count} fail={fail_count} "
f"na={na_count} total={len(results)}"
)
else:
table = Table(title="Support Matrix Results")
table.add_column("Python", style="cyan", no_wrap=True)
table.add_column("Row", style="white")
table.add_column("Status", no_wrap=True)
table.add_column("Reason")
table.add_column("Duration (s)", justify="right", no_wrap=True)
for result in results:
table.add_row(
result.python_version,
result.row_id,
f"[{status_style(result.status)}]{result.status}[/{status_style(result.status)}]",
result.reason,
f"{result.duration_sec:.3f}",
)
CONSOLE.print()
CONSOLE.print(table)
CONSOLE.print(
f"[bold]Summary[/bold] "
f"pass=[green]{pass_count}[/green] "
f"fail=[bold red]{fail_count}[/bold red] "
f"na=[yellow]{na_count}[/yellow] "
f"total={len(results)}"
)
diagnostics = [row for row in results if row.status in {"FAIL", "N/A"} and row.hint]
if diagnostics:
if CONSOLE is None:
print("\n[matrix] diagnostics (failed/n-a cases)")
for row in diagnostics:
print(
f"- py={row.python_version} row={row.row_id} "
f"status={row.status} reason={row.reason}"
)
print(f" hint: {row.hint}")
if row.log_path:
print(f" log: {row.log_path}")
else:
diagnostics_table = Table(title="Diagnostics (FAIL / N/A)")
diagnostics_table.add_column("Case", style="cyan")
diagnostics_table.add_column("Status", no_wrap=True)
diagnostics_table.add_column("Reason")
diagnostics_table.add_column("Hint")
diagnostics_table.add_column("Log")
for row in diagnostics:
diagnostics_table.add_row(
f"py={row.python_version} {row.row_id}",
f"[{status_style(row.status)}]{row.status}[/{status_style(row.status)}]",
row.reason,
row.hint,
row.log_path,
)
CONSOLE.print()
CONSOLE.print(diagnostics_table)
def main() -> int:
args = parse_args()
if args.timeout_sec <= 0:
print("[matrix] error: --timeout-sec must be > 0", file=sys.stderr)
return 1
repo_root = Path(__file__).resolve().parents[1]
logs_dir = (repo_root / args.logs_dir).resolve()
logs_dir.mkdir(parents=True, exist_ok=True)
print_line(f"[matrix] repo_root={repo_root}", style="cyan")
print_line(f"[matrix] timeout_sec={args.timeout_sec}", style="cyan")
print_line(f"[matrix] logs_dir={logs_dir}", style="cyan")
try:
sample_audio = download_sample(repo_root)
except Exception as exc: # pragma: no cover - straightforward failure path
if CONSOLE is None:
print(f"[matrix] sample_download_failed: {exc}", file=sys.stderr)
else:
CONSOLE.print(
f"[matrix] sample_download_failed: {exc}",
style="bold red",
highlight=False,
)
return 1
print_line(f"[matrix] sample_audio={sample_audio}", style="cyan")
gpu_available = detect_gpu_available()
print_line(f"[matrix] gpu_available={gpu_available}", style="cyan")
results: list[CaseResult] = []
for python_version in PYTHON_VERSIONS:
for row in CASES:
print_line(
f"\n[matrix] running py={python_version} row={row.row_id}", style="blue"
)
result = run_case(
repo_root=repo_root,
python_version=python_version,
row=row,
sample_audio=sample_audio,
timeout_sec=args.timeout_sec,
gpu_available=gpu_available,
logs_dir=logs_dir,
)
result = apply_expected_failure_policy(result)
results.append(result)
print_line(
f"[matrix] {result.status} py={result.python_version} "
f"row={result.row_id} reason={result.reason} duration={result.duration_sec:.3f}s",
style=status_style(result.status),
)
if result.log_path:
print_line(f"[matrix] log={result.log_path}", style="dim")
print_summary(results)
fail_count = sum(1 for row in results if row.status == "FAIL")
return 1 if fail_count else 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,9 +1,11 @@
import shutil """Copy core files from web directory to Chrome extension directory."""
import os import os
import shutil
from pathlib import Path from pathlib import Path
def sync_extension_files(): def sync_extension_files():
"""Copy core files from web directory to Chrome extension directory."""
web_dir = Path("whisperlivekit/web") web_dir = Path("whisperlivekit/web")
extension_dir = Path("chrome-extension") extension_dir = Path("chrome-extension")

803
test_backend_offline.py Normal file
View File

@@ -0,0 +1,803 @@
#!/usr/bin/env python3
"""
Offline test harness and benchmark suite for WhisperLiveKit backends.
Simulates a client-server session by feeding audio files as PCM bytes through
the full AudioProcessor pipeline (the same path used by the WebSocket server),
without needing a browser or microphone.
Computes WER (Word Error Rate) and timestamp accuracy when ground truth
transcript files (.transcript.json) are available alongside audio files.
Usage:
# Test with a single audio file:
python test_backend_offline.py --backend faster-whisper --audio audio_tests/00_00_07_english_1_speaker.wav
# Test all files in audio_tests/:
python test_backend_offline.py --backend faster-whisper --no-realtime
# Override streaming policy:
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --no-realtime
# Multi-backend benchmark (auto-detects all installed backends):
python test_backend_offline.py --benchmark --no-realtime
# Export results as JSON:
python test_backend_offline.py --benchmark --no-realtime --json results.json
# Insert silence for testing silence handling:
python test_backend_offline.py --backend faster-whisper --insert-silence 3.0 2.0
"""
import argparse
import asyncio
import json
import logging
import sys
import time
import urllib.request
from pathlib import Path
from dataclasses import dataclass, asdict, field
from typing import List, Optional
import numpy as np
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("test_offline")
logger.setLevel(logging.INFO)
SAMPLE_RATE = 16000
JFK_WAV_URL = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
CACHE_DIR = Path(__file__).parent / ".test_cache"
AUDIO_TESTS_DIR = Path(__file__).parent / "audio_tests"
AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
@dataclass
class WordTimestamp:
"""Word with its start/end time."""
word: str
start: float
end: float
@dataclass
class TestResult:
"""Structured result from a single test run."""
audio_file: str
audio_duration_s: float
backend: str
policy: str
language: str
chunk_ms: int
realtime_pacing: bool
# Timing
processing_time_s: float
rtf: float # real-time factor
# Transcription output
transcription: str
n_lines: int
n_responses: int
# WER metrics (None if no ground truth)
wer: Optional[float] = None
wer_details: Optional[dict] = None
# Timestamp accuracy (None if no ground truth)
timestamp_mae: Optional[float] = None
timestamp_max_delta: Optional[float] = None
timestamp_median_delta: Optional[float] = None
# Word-level timestamps
word_timestamps: List[WordTimestamp] = field(default_factory=list)
# Raw last response
last_response: Optional[dict] = None
def download_sample_audio() -> Path:
"""Download the jfk.wav sample if not cached."""
CACHE_DIR.mkdir(exist_ok=True)
path = CACHE_DIR / "jfk.wav"
if not path.exists():
logger.info(f"Downloading sample audio to {path} ...")
urllib.request.urlretrieve(JFK_WAV_URL, path)
logger.info("Done.")
return path
def load_audio(path: str) -> np.ndarray:
"""Load audio file as float32 mono 16kHz numpy array.
Supports WAV, FLAC (via soundfile) and MP3, OGG, M4A (via librosa).
"""
ext = Path(path).suffix.lower()
if ext in (".mp3", ".ogg", ".m4a"):
import librosa
audio, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
return audio.astype(np.float32)
import soundfile as sf
audio, sr = sf.read(path, dtype="float32")
if audio.ndim > 1:
audio = audio.mean(axis=1)
if sr != SAMPLE_RATE:
import librosa
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
return audio
def insert_silence(audio: np.ndarray, silence_sec: float, position_sec: float) -> np.ndarray:
"""Insert silence into audio at a given position.
Args:
audio: Float32 mono audio array at SAMPLE_RATE.
silence_sec: Duration of silence to insert in seconds.
position_sec: Position in seconds where silence starts.
Returns:
New audio array with silence inserted.
"""
pos_samples = int(position_sec * SAMPLE_RATE)
silence_samples = int(silence_sec * SAMPLE_RATE)
pos_samples = min(pos_samples, len(audio))
silence = np.zeros(silence_samples, dtype=np.float32)
return np.concatenate([audio[:pos_samples], silence, audio[pos_samples:]])
def float32_to_s16le_bytes(audio: np.ndarray) -> bytes:
"""Convert float32 audio to s16le PCM bytes (what the browser sends)."""
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
def create_engine(
backend: str, model_size: str, lan: str,
diarization: bool = False,
diarization_backend: str = "",
vac: bool = True,
policy: str = "",
):
"""Create a TranscriptionEngine with the given backend config."""
import gc
from whisperlivekit.core import TranscriptionEngine
# Reset singleton so we get a fresh instance
TranscriptionEngine._instance = None
TranscriptionEngine._initialized = False
gc.collect()
kwargs = dict(
backend=backend,
lan=lan,
pcm_input=True,
vac=vac,
transcription=True,
diarization=diarization,
)
if diarization_backend:
kwargs["diarization_backend"] = diarization_backend
if model_size:
kwargs["model_size"] = model_size
if policy:
kwargs["backend_policy"] = policy
return TranscriptionEngine(**kwargs)
def _extract_text_from_response(response_dict: dict) -> str:
"""Extract full transcription text from a FrontData dict."""
def _strip_or_empty(value: object) -> str:
return value.strip() if isinstance(value, str) else ""
segments = response_dict.get("lines", [])
full_text = " ".join(
text
for seg in segments
if isinstance(seg, dict)
for text in [_strip_or_empty(seg.get("text"))]
if text
)
buf = _strip_or_empty(response_dict.get("buffer_transcription"))
if buf:
full_text = f"{full_text} {buf}".strip() if full_text else buf
return full_text
async def run_test(
engine, audio: np.ndarray, chunk_ms: int, realtime: bool,
audio_file: str = "", backend: str = "", policy: str = "", lan: str = "",
) -> TestResult:
"""
Simulate a client session through the full AudioProcessor pipeline.
1. Create AudioProcessor (one per "client session")
2. Start async pipeline (transcription_processor, results_formatter, etc.)
3. Feed audio as PCM bytes in timed chunks
4. Collect and display FrontData responses
5. Signal EOF and cleanup
"""
from whisperlivekit.audio_processor import AudioProcessor
chunk_samples = int(SAMPLE_RATE * chunk_ms / 1000)
total_samples = len(audio)
audio_duration = total_samples / SAMPLE_RATE
logger.info(
f"Audio: {audio_duration:.2f}s | "
f"Chunk: {chunk_ms}ms ({chunk_samples} samples) | "
f"Steps: {total_samples // chunk_samples + 1} | "
f"Realtime: {realtime}"
)
# --- Server side: create processor and start pipeline ---
processor = AudioProcessor(transcription_engine=engine)
results_generator = await processor.create_tasks()
# Collect results in background (like handle_websocket_results)
all_responses = []
response_count = 0
last_printed_text = ""
async def collect_results():
nonlocal response_count, last_printed_text
async for response in results_generator:
all_responses.append(response)
response_count += 1
d = response.to_dict()
# Only print when transcription text actually changes
current_text = _extract_text_from_response(d)
if current_text and current_text != last_printed_text:
buf = d.get("buffer_transcription")
buf = buf.strip() if isinstance(buf, str) else ""
committed = current_text
if buf and committed.endswith(buf):
committed = committed[:-len(buf)].strip()
# Show committed text + buffer separately
display = committed
if buf:
display = f"{committed} \033[90m{buf}\033[0m" if committed else f"\033[90m{buf}\033[0m"
print(f" > {display}", flush=True)
last_printed_text = current_text
result_task = asyncio.create_task(collect_results())
# --- Client side: feed audio as PCM bytes ---
t_start = time.time()
for offset in range(0, total_samples, chunk_samples):
chunk = audio[offset : offset + chunk_samples]
pcm_bytes = float32_to_s16le_bytes(chunk)
await processor.process_audio(pcm_bytes)
if realtime:
await asyncio.sleep(chunk_ms / 1000)
feed_elapsed = time.time() - t_start
logger.info(f"Audio fed in {feed_elapsed:.2f}s. Signaling EOF...")
# Signal end of audio (like client disconnect / empty message)
await processor.process_audio(None)
# Wait for pipeline to drain completely
try:
await asyncio.wait_for(result_task, timeout=120.0)
except asyncio.TimeoutError:
logger.warning("Timed out waiting for results. Proceeding with cleanup.")
result_task.cancel()
try:
await result_task
except asyncio.CancelledError:
pass
# --- Capture word-level timestamps before cleanup ---
word_timestamps = []
try:
state = await processor.get_current_state()
for token in state.tokens:
if hasattr(token, 'start') and hasattr(token, 'text') and token.text:
word_timestamps.append(WordTimestamp(
word=token.text.strip(),
start=round(token.start, 3),
end=round(token.end, 3),
))
except Exception as e:
logger.warning(f"Could not capture word timestamps: {e}")
# Cleanup
await processor.cleanup()
total_elapsed = time.time() - t_start
# --- Build result ---
transcription = ""
n_lines = 0
last_response_dict = None
if all_responses:
last = all_responses[-1].to_dict()
last_response_dict = last
n_lines = len(last.get("lines", []))
transcription = _extract_text_from_response(last)
# --- Compute WER and timestamp accuracy against ground truth ---
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy
wer_val = None
wer_details = None
ts_mae = None
ts_max_delta = None
ts_median_delta = None
gt_path = Path(audio_file).with_suffix(".transcript.json")
if not gt_path.exists():
gt_path = AUDIO_TESTS_DIR / gt_path
gt = None
if gt_path.exists():
with open(gt_path) as f:
gt = json.load(f)
# WER
gt_text = " ".join(w["word"] for w in gt)
wer_result = compute_wer(gt_text, transcription)
wer_val = round(wer_result["wer"], 4)
wer_details = wer_result
# Timestamp accuracy
if word_timestamps:
pred_dicts = [{"word": wt.word, "start": wt.start, "end": wt.end} for wt in word_timestamps]
ts_result = compute_timestamp_accuracy(pred_dicts, gt)
ts_mae = ts_result["mae_start"]
ts_max_delta = ts_result["max_delta_start"]
ts_median_delta = ts_result["median_delta_start"]
result = TestResult(
audio_file=audio_file,
audio_duration_s=round(audio_duration, 2),
backend=backend,
policy=policy,
language=lan,
chunk_ms=chunk_ms,
realtime_pacing=realtime,
processing_time_s=round(total_elapsed, 2),
rtf=round(total_elapsed / audio_duration, 2),
transcription=transcription,
n_lines=n_lines,
n_responses=response_count,
wer=wer_val,
wer_details=wer_details,
timestamp_mae=round(ts_mae, 3) if ts_mae is not None else None,
timestamp_max_delta=round(ts_max_delta, 3) if ts_max_delta is not None else None,
timestamp_median_delta=round(ts_median_delta, 3) if ts_median_delta is not None else None,
word_timestamps=word_timestamps,
last_response=last_response_dict,
)
# --- Print summary ---
print(f"\n{'=' * 60}")
print(f"RESULT: {audio_file}")
print(f"{'=' * 60}")
print(f"Transcription: {transcription}")
print(f"Lines: {n_lines} | Responses: {response_count}")
print(f"Audio: {audio_duration:.2f}s | Time: {total_elapsed:.2f}s | RTF: {result.rtf:.2f}x")
if wer_val is not None:
print(f"WER: {wer_val:.2%} (S={wer_details['substitutions']} I={wer_details['insertions']} D={wer_details['deletions']})")
# Print word timestamps if available
if word_timestamps:
print(f"\nWord timestamps ({len(word_timestamps)} words):")
for wt in word_timestamps:
print(f" [{wt.start:6.2f} - {wt.end:6.2f}] {wt.word}")
# Detailed comparison with ground truth
if gt:
print(f"\n vs Ground truth ({len(gt)} words):")
max_words = max(len(word_timestamps), len(gt))
for i in range(max_words):
pred = word_timestamps[i] if i < len(word_timestamps) else None
ref = gt[i] if i < len(gt) else None
p_str = f"[{pred.start:5.2f}-{pred.end:5.2f}] {pred.word:<15}" if pred else " " * 30
r_str = f"[{ref['start']:5.2f}-{ref['end']:5.2f}] {ref['word']:<15}" if ref else ""
delta = ""
if pred and ref:
d = pred.start - ref['start']
delta = f" Δstart={d:+.2f}"
print(f" {p_str} | {r_str}{delta}")
if ts_mae is not None:
print(f"\n Timestamp stats: MAE={ts_mae:.3f}s max|Δ|={ts_max_delta:.3f}s median|Δ|={ts_median_delta:.3f}s")
print(f"{'=' * 60}")
return result
def discover_audio_files(directory: str) -> List[Path]:
"""Find all supported audio files in directory."""
d = Path(directory)
files = sorted(
p for p in d.iterdir()
if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
)
return files
async def run_all_tests(
engine, audio_files: List[Path], chunk_ms: int, realtime: bool,
backend: str, policy: str, lan: str, max_duration: float = 60.0,
silence_insertions: Optional[List[List[float]]] = None,
) -> List[TestResult]:
"""Run tests on multiple audio files sequentially."""
results = []
for audio_path in audio_files:
# Detect language from filename if "french" in name
file_lan = lan
if "french" in audio_path.name.lower() and lan == "en":
file_lan = "fr"
logger.info(f"Auto-detected language 'fr' from filename")
audio = load_audio(str(audio_path))
# Insert silence segments (applied in reverse position order to keep offsets valid)
if silence_insertions:
for secs, at_sec in sorted(silence_insertions, key=lambda x: x[1], reverse=True):
logger.info(f"Inserting {secs:.1f}s silence at {at_sec:.1f}s")
audio = insert_silence(audio, secs, at_sec)
duration = len(audio) / SAMPLE_RATE
if duration > max_duration:
logger.info(f"Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s max)")
continue
print(f"\n{'#' * 60}")
print(f"# Testing: {audio_path.name} ({duration:.1f}s)")
print(f"{'#' * 60}")
result = await run_test(
engine, audio, chunk_ms, realtime,
audio_file=audio_path.name, backend=backend, policy=policy, lan=file_lan,
)
results.append(result)
return results
def print_benchmark_summary(results: List[TestResult]):
"""Print a tabular summary of all test results."""
print(f"\n{'=' * 110}")
print("BENCHMARK SUMMARY")
print(f"{'=' * 110}")
print(
f"{'File':<40} {'Duration':>8} {'Time':>8} {'RTF':>6} "
f"{'WER':>7} {'MAE(s)':>7} {'Lines':>5}"
)
print(f"{'-' * 110}")
for r in results:
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
print(
f"{r.audio_file:<40} {r.audio_duration_s:>7.1f}s {r.processing_time_s:>7.1f}s "
f"{r.rtf:>5.2f}x {wer_str:>7} {mae_str:>7} {r.n_lines:>5}"
)
print(f"{'-' * 110}")
total_audio = sum(r.audio_duration_s for r in results)
total_time = sum(r.processing_time_s for r in results)
avg_rtf = total_time / total_audio if total_audio > 0 else 0
wer_vals = [r.wer for r in results if r.wer is not None]
avg_wer_str = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
mae_vals = [r.timestamp_mae for r in results if r.timestamp_mae is not None]
avg_mae_str = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
print(
f"{'TOTAL/AVG':<40} {total_audio:>7.1f}s {total_time:>7.1f}s "
f"{avg_rtf:>5.2f}x {avg_wer_str:>7} {avg_mae_str:>7}"
)
print(f"{'=' * 110}")
# Print transcription excerpts
print(f"\nTRANSCRIPTIONS:")
print(f"{'-' * 110}")
for r in results:
excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription
print(f" {r.audio_file}:")
print(f" {excerpt}")
print(f"{'=' * 110}")
def detect_available_backends() -> List[dict]:
"""Probe which backends can be imported and return (backend, policy) combos.
Returns list of dicts with keys: backend, policy, description.
"""
combos = []
# faster-whisper
try:
import faster_whisper # noqa: F401
combos.append({"backend": "faster-whisper", "policy": "localagreement", "description": "faster-whisper + LocalAgreement"})
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "description": "faster-whisper + SimulStreaming"})
except ImportError:
pass
# mlx-whisper (macOS only)
try:
import mlx_whisper # noqa: F401
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "description": "mlx-whisper + LocalAgreement"})
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "description": "mlx-whisper + SimulStreaming"})
except ImportError:
pass
# openai-whisper
try:
import whisper # noqa: F401
combos.append({"backend": "whisper", "policy": "localagreement", "description": "openai-whisper + LocalAgreement"})
combos.append({"backend": "whisper", "policy": "simulstreaming", "description": "openai-whisper + SimulStreaming"})
except ImportError:
pass
# voxtral-mlx
try:
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "description": "voxtral-mlx (MLX)"})
except ImportError:
pass
# voxtral (HuggingFace)
try:
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
combos.append({"backend": "voxtral", "policy": "voxtral", "description": "voxtral (HuggingFace)"})
except ImportError:
pass
return combos
def print_cross_backend_comparison(all_results: List[TestResult]):
"""Print a comparison table across backends and policies."""
print(f"\n{'=' * 110}")
print("CROSS-BACKEND BENCHMARK COMPARISON")
print(f"{'=' * 110}")
print(
f"{'Backend':<18} {'Policy':<16} {'File':<30} "
f"{'WER':>7} {'RTF':>6} {'MAE(s)':>7} {'MaxΔ(s)':>8}"
)
print(f"{'-' * 110}")
for r in all_results:
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
rtf_str = f"{r.rtf:.2f}x"
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
max_str = f"{r.timestamp_max_delta:.3f}" if r.timestamp_max_delta is not None else " -"
# Truncate filename for readability
fname = r.audio_file[:28] + ".." if len(r.audio_file) > 30 else r.audio_file
print(
f"{r.backend:<18} {r.policy:<16} {fname:<30} "
f"{wer_str:>7} {rtf_str:>6} {mae_str:>7} {max_str:>8}"
)
print(f"{'-' * 110}")
# Per-backend averages
from collections import defaultdict
by_combo = defaultdict(list)
for r in all_results:
by_combo[(r.backend, r.policy)].append(r)
print(f"\n{'Backend':<18} {'Policy':<16} {'Avg WER':>8} {'Avg RTF':>8} {'Avg MAE':>8} {'Files':>6}")
print(f"{'-' * 80}")
for (backend, policy), group in sorted(by_combo.items()):
wer_vals = [r.wer for r in group if r.wer is not None]
rtf_vals = [r.rtf for r in group]
mae_vals = [r.timestamp_mae for r in group if r.timestamp_mae is not None]
avg_wer = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
avg_rtf = f"{sum(rtf_vals)/len(rtf_vals):.2f}x"
avg_mae = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
print(
f"{backend:<18} {policy:<16} {avg_wer:>8} {avg_rtf:>8} {avg_mae:>8} {len(group):>6}"
)
print(f"{'=' * 110}")
def _quiet_loggers(verbose: bool):
"""Set internal module log levels to reduce noise."""
if verbose:
logging.getLogger().setLevel(logging.DEBUG)
else:
for mod in (
"whisperlivekit.audio_processor", "whisperlivekit.simul_whisper",
"whisperlivekit.tokens_alignment", "whisperlivekit.simul_whisper.align_att_base",
"whisperlivekit.simul_whisper.simul_whisper",
):
logging.getLogger(mod).setLevel(logging.WARNING)
async def run_benchmark(
audio_files: List[Path], chunk_ms: int, realtime: bool,
model_size: str, lan: str, max_duration: float, vac: bool,
verbose: bool,
) -> List[TestResult]:
"""Run benchmark across all available backend+policy combinations."""
combos = detect_available_backends()
if not combos:
logger.error("No backends available. Install at least one ASR backend.")
return []
logger.info(f"Detected {len(combos)} backend+policy combinations:")
for c in combos:
logger.info(f" - {c['description']}")
all_results = []
for i, combo in enumerate(combos, 1):
backend = combo["backend"]
policy = combo["policy"]
desc = combo["description"]
print(f"\n{'*' * 70}")
print(f"* BENCHMARK {i}/{len(combos)}: {desc}")
print(f"{'*' * 70}")
try:
engine = create_engine(
backend, model_size, lan, vac=vac, policy=policy,
)
_quiet_loggers(verbose)
results = await run_all_tests(
engine, audio_files, chunk_ms, realtime,
backend=backend, policy=policy, lan=lan,
max_duration=max_duration,
)
all_results.extend(results)
except Exception as e:
logger.error(f"Failed to run {desc}: {e}")
import traceback
traceback.print_exc()
return all_results
def main():
parser = argparse.ArgumentParser(
description="Offline backend test harness (AudioProcessor-level)"
)
parser.add_argument(
"--backend", default="faster-whisper",
help="Backend: voxtral, voxtral-mlx, auto, faster-whisper, mlx-whisper, whisper.",
)
parser.add_argument(
"--policy", default="",
help="Override backend policy: localagreement, simulstreaming, voxtral.",
)
parser.add_argument(
"--audio", default=None,
help="Path to a single audio file (WAV, MP3, FLAC, etc.).",
)
parser.add_argument(
"--audio-dir", default=None,
help="Directory of audio files to test. Defaults to audio_tests/ if neither --audio nor --audio-dir given.",
)
parser.add_argument(
"--chunk-ms", type=int, default=100,
help="Chunk size in milliseconds (simulates real-time interval).",
)
parser.add_argument(
"--model", default="", dest="model_size",
help="Model size or HF repo ID.",
)
parser.add_argument("--lan", default="en", help="Language code.")
parser.add_argument(
"--no-realtime", action="store_true",
help="Skip real-time pacing between chunks (faster but less realistic).",
)
parser.add_argument(
"--no-vac", action="store_true",
help="Disable Voice Activity Classification (send all audio without silence filtering).",
)
parser.add_argument(
"--diarization", action="store_true",
help="Enable speaker diarization.",
)
parser.add_argument(
"--diarization-backend",
default="",
choices=["diart", "sortformer"],
help="Diarization backend when --diarization is enabled.",
)
parser.add_argument(
"--benchmark", action="store_true",
help="Run benchmark across all detected backend+policy combinations.",
)
parser.add_argument(
"--json", default=None, dest="json_output",
help="Write structured JSON results to this file.",
)
parser.add_argument(
"--max-duration", type=float, default=60.0,
help="Skip audio files longer than this many seconds (default: 60).",
)
parser.add_argument(
"--insert-silence", nargs=2, type=float, metavar=("SECS", "AT_SEC"),
action="append", default=[],
help="Insert SECS of silence at AT_SEC position. Can be repeated. "
"E.g.: --insert-silence 3.0 2.0 --insert-silence 5.0 7.0",
)
parser.add_argument(
"-v", "--verbose", action="store_true",
help="Show debug-level logs from all components.",
)
args = parser.parse_args()
realtime = not args.no_realtime
vac = not args.no_vac
# Resolve audio file(s)
if args.audio:
audio_files = [Path(args.audio)]
elif args.audio_dir:
audio_files = discover_audio_files(args.audio_dir)
elif AUDIO_TESTS_DIR.is_dir():
audio_files = discover_audio_files(str(AUDIO_TESTS_DIR))
else:
# Fall back to jfk.wav download
audio_files = [download_sample_audio()]
if not audio_files:
logger.error("No audio files found.")
sys.exit(1)
logger.info(f"Audio files: {[f.name for f in audio_files]}")
if args.benchmark:
# --- Multi-backend benchmark mode ---
all_results = asyncio.run(
run_benchmark(
audio_files, args.chunk_ms, realtime,
args.model_size, args.lan, args.max_duration, vac,
args.verbose,
)
)
if all_results:
print_cross_backend_comparison(all_results)
results = all_results
else:
# --- Single-backend mode ---
policy = args.policy
logger.info(f"Creating {args.backend} engine...")
engine = create_engine(
args.backend, args.model_size, args.lan,
diarization=args.diarization,
diarization_backend=args.diarization_backend,
vac=vac,
policy=policy,
)
logger.info("Engine ready.")
_quiet_loggers(args.verbose)
results = asyncio.run(
run_all_tests(
engine, audio_files, args.chunk_ms, realtime,
args.backend, policy, args.lan,
max_duration=args.max_duration,
silence_insertions=args.insert_silence or None,
)
)
if len(results) > 1:
print_benchmark_summary(results)
# JSON output
if args.json_output and results:
json_results = []
for r in results:
d = asdict(r)
d.pop("last_response", None) # too verbose for summary
json_results.append(d)
Path(args.json_output).write_text(
json.dumps(json_results, indent=2, ensure_ascii=False)
)
logger.info(f"Results written to {args.json_output}")
if __name__ == "__main__":
main()

58
tests/conftest.py Normal file
View File

@@ -0,0 +1,58 @@
"""Shared pytest fixtures for WhisperLiveKit tests."""
import json
from pathlib import Path
from types import SimpleNamespace
import pytest
from whisperlivekit.timed_objects import ASRToken, Silence, Transcript
AUDIO_TESTS_DIR = Path(__file__).parent.parent / "audio_tests"
@pytest.fixture
def sample_tokens():
"""A short sequence of ASRToken objects."""
return [
ASRToken(start=0.0, end=0.5, text="Hello"),
ASRToken(start=0.5, end=1.0, text=" world"),
ASRToken(start=1.0, end=1.5, text=" test."),
]
@pytest.fixture
def sample_silence():
"""A completed silence event."""
s = Silence(start=1.5, end=3.0, is_starting=False, has_ended=True)
s.compute_duration()
return s
@pytest.fixture
def mock_args():
"""Minimal args namespace for AudioProcessor tests."""
return SimpleNamespace(
diarization=False,
transcription=True,
target_language="",
vac=False,
vac_chunk_size=0.04,
min_chunk_size=0.1,
pcm_input=True,
punctuation_split=False,
backend="faster-whisper",
backend_policy="localagreement",
vad=True,
)
@pytest.fixture
def ground_truth_en():
"""Ground truth transcript for the 7s English audio (if available)."""
path = AUDIO_TESTS_DIR / "00_00_07_english_1_speaker.transcript.json"
if path.exists():
with open(path) as f:
return json.load(f)
return None

View File

@@ -0,0 +1,209 @@
"""Tests for AudioProcessor pipeline with mocked ASR backends.
These tests verify the async audio processing pipeline works correctly
without requiring any real ASR models to be loaded.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
import numpy as np
import pytest
from whisperlivekit.timed_objects import ASRToken, Transcript
# ---------------------------------------------------------------------------
# Mock ASR components
# ---------------------------------------------------------------------------
class MockASR:
"""Mock ASR model holder."""
sep = " "
SAMPLING_RATE = 16000
def __init__(self):
self.transcribe_kargs = {}
self.original_language = "en"
self.backend_choice = "mock"
def transcribe(self, audio):
return None
class MockOnlineProcessor:
"""Mock online processor that returns canned tokens."""
SAMPLING_RATE = 16000
def __init__(self, asr=None):
self.asr = asr or MockASR()
self.audio_buffer = np.array([], dtype=np.float32)
self.end = 0.0
self._call_count = 0
self._finished = False
def insert_audio_chunk(self, audio, audio_stream_end_time):
self.audio_buffer = np.append(self.audio_buffer, audio)
self.end = audio_stream_end_time
def process_iter(self, is_last=False):
self._call_count += 1
# Emit a token on every call when we have audio
if len(self.audio_buffer) > 0:
t = self._call_count * 0.5
return [ASRToken(start=t, end=t + 0.5, text=f"word{self._call_count}")], self.end
return [], self.end
def get_buffer(self):
return Transcript(start=None, end=None, text="")
def start_silence(self):
return [], self.end
def end_silence(self, silence_duration, offset):
pass
def new_speaker(self, change_speaker):
pass
def finish(self):
self._finished = True
return [], self.end
def warmup(self, audio, init_prompt=""):
pass
def _make_pcm_bytes(duration_s=0.1, sample_rate=16000):
"""Generate silent PCM s16le bytes."""
n_samples = int(duration_s * sample_rate)
audio = np.zeros(n_samples, dtype=np.float32)
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_engine():
"""Create a mock TranscriptionEngine-like object."""
engine = SimpleNamespace(
asr=MockASR(),
diarization_model=None,
translation_model=None,
args=SimpleNamespace(
diarization=False,
transcription=True,
target_language="",
vac=False,
vac_chunk_size=0.04,
min_chunk_size=0.1,
pcm_input=True,
punctuation_split=False,
backend="mock",
backend_policy="localagreement",
vad=True,
model_size="base",
lan="en",
),
)
return engine
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestPCMConversion:
"""Test PCM byte conversion without needing the full pipeline."""
def test_s16le_roundtrip(self):
"""Convert float32 → s16le → float32 and verify approximate roundtrip."""
original = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype=np.float32)
s16 = (original * 32768).clip(-32768, 32767).astype(np.int16)
pcm_bytes = s16.tobytes()
# Direct numpy conversion (same logic as AudioProcessor.convert_pcm_to_float)
recovered = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
np.testing.assert_allclose(recovered, original, atol=1 / 32768)
@pytest.mark.asyncio
class TestPipelineBasics:
async def test_feed_audio_and_get_responses(self, mock_engine):
"""Feed audio through the pipeline and verify we get responses."""
from whisperlivekit.audio_processor import AudioProcessor
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
processor = AudioProcessor(transcription_engine=mock_engine)
results_gen = await processor.create_tasks()
responses = []
async def collect():
async for resp in results_gen:
responses.append(resp)
task = asyncio.create_task(collect())
# Feed 2 seconds of audio in 100ms chunks
for _ in range(20):
await processor.process_audio(_make_pcm_bytes(0.1))
# Signal EOF
await processor.process_audio(None)
await asyncio.wait_for(task, timeout=10.0)
await processor.cleanup()
# We should have gotten at least one response
assert len(responses) > 0
async def test_eof_terminates_pipeline(self, mock_engine):
"""Sending None (EOF) should cleanly terminate the pipeline."""
from whisperlivekit.audio_processor import AudioProcessor
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
processor = AudioProcessor(transcription_engine=mock_engine)
results_gen = await processor.create_tasks()
responses = []
async def collect():
async for resp in results_gen:
responses.append(resp)
task = asyncio.create_task(collect())
# Send a small amount of audio then EOF
await processor.process_audio(_make_pcm_bytes(0.5))
await processor.process_audio(None)
await asyncio.wait_for(task, timeout=10.0)
await processor.cleanup()
# Pipeline should have terminated without error
assert task.done()
async def test_empty_audio_no_crash(self, mock_engine):
"""Sending EOF immediately (no audio) should not crash."""
from whisperlivekit.audio_processor import AudioProcessor
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
processor = AudioProcessor(transcription_engine=mock_engine)
results_gen = await processor.create_tasks()
responses = []
async def collect():
async for resp in results_gen:
responses.append(resp)
task = asyncio.create_task(collect())
await processor.process_audio(None)
await asyncio.wait_for(task, timeout=10.0)
await processor.cleanup()
assert task.done()

99
tests/test_config.py Normal file
View File

@@ -0,0 +1,99 @@
"""Tests for WhisperLiveKitConfig."""
import logging
from types import SimpleNamespace
import pytest
from whisperlivekit.config import WhisperLiveKitConfig
class TestDefaults:
def test_default_backend(self):
c = WhisperLiveKitConfig()
assert c.backend == "auto"
def test_default_policy(self):
c = WhisperLiveKitConfig()
assert c.backend_policy == "simulstreaming"
def test_default_language(self):
c = WhisperLiveKitConfig()
assert c.lan == "auto"
def test_default_vac(self):
c = WhisperLiveKitConfig()
assert c.vac is True
def test_default_model_size(self):
c = WhisperLiveKitConfig()
assert c.model_size == "base"
def test_default_transcription(self):
c = WhisperLiveKitConfig()
assert c.transcription is True
assert c.diarization is False
class TestPostInit:
def test_en_model_forces_english(self):
c = WhisperLiveKitConfig(model_size="tiny.en")
assert c.lan == "en"
def test_en_suffix_with_auto_language(self):
c = WhisperLiveKitConfig(model_size="base.en", lan="auto")
assert c.lan == "en"
def test_non_en_model_keeps_language(self):
c = WhisperLiveKitConfig(model_size="base", lan="fr")
assert c.lan == "fr"
def test_policy_alias_1(self):
c = WhisperLiveKitConfig(backend_policy="1")
assert c.backend_policy == "simulstreaming"
def test_policy_alias_2(self):
c = WhisperLiveKitConfig(backend_policy="2")
assert c.backend_policy == "localagreement"
def test_policy_no_alias(self):
c = WhisperLiveKitConfig(backend_policy="localagreement")
assert c.backend_policy == "localagreement"
class TestFromNamespace:
def test_known_keys(self):
ns = SimpleNamespace(backend="faster-whisper", lan="en", model_size="large-v3")
c = WhisperLiveKitConfig.from_namespace(ns)
assert c.backend == "faster-whisper"
assert c.lan == "en"
assert c.model_size == "large-v3"
def test_ignores_unknown_keys(self):
ns = SimpleNamespace(backend="auto", unknown_key="value", another="x")
c = WhisperLiveKitConfig.from_namespace(ns)
assert c.backend == "auto"
assert not hasattr(c, "unknown_key")
def test_preserves_defaults_for_missing(self):
ns = SimpleNamespace(backend="voxtral-mlx")
c = WhisperLiveKitConfig.from_namespace(ns)
assert c.lan == "auto"
assert c.vac is True
class TestFromKwargs:
def test_known_keys(self):
c = WhisperLiveKitConfig.from_kwargs(backend="mlx-whisper", lan="fr")
assert c.backend == "mlx-whisper"
assert c.lan == "fr"
def test_warns_on_unknown_keys(self, caplog):
with caplog.at_level(logging.WARNING, logger="whisperlivekit.config"):
c = WhisperLiveKitConfig.from_kwargs(backend="auto", bogus="value")
assert c.backend == "auto"
assert "bogus" in caplog.text
def test_post_init_runs(self):
c = WhisperLiveKitConfig.from_kwargs(model_size="small.en")
assert c.lan == "en"

View File

@@ -0,0 +1,172 @@
"""Tests for HypothesisBuffer — the core of LocalAgreement policy."""
import pytest
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.local_agreement.online_asr import HypothesisBuffer
def make_tokens(words, start=0.0, step=0.5):
"""Helper: create ASRToken list from word strings."""
tokens = []
t = start
for w in words:
tokens.append(ASRToken(start=t, end=t + step, text=w, probability=0.9))
t += step
return tokens
class TestInsert:
def test_basic_insert(self):
buf = HypothesisBuffer()
tokens = make_tokens(["hello", "world"])
buf.insert(tokens, offset=0.0)
assert len(buf.new) == 2
assert buf.new[0].text == "hello"
def test_insert_with_offset(self):
buf = HypothesisBuffer()
tokens = make_tokens(["hello"], start=0.0)
buf.insert(tokens, offset=5.0)
assert buf.new[0].start == pytest.approx(5.0)
def test_insert_filters_old_tokens(self):
buf = HypothesisBuffer()
buf.last_committed_time = 10.0
tokens = make_tokens(["old", "new"], start=5.0, step=3.0)
buf.insert(tokens, offset=0.0)
# "old" at 5.0 is before last_committed_time - 0.1 = 9.9 → filtered
# "new" at 8.0 is also before 9.9 → filtered
assert len(buf.new) == 0
def test_insert_deduplicates_committed(self):
buf = HypothesisBuffer()
# Commit "hello"
tokens1 = make_tokens(["hello", "world"])
buf.insert(tokens1, offset=0.0)
buf.flush() # commits "hello" (buffer was empty, so nothing matches)
# Actually with empty buffer, flush won't commit anything
# Let's do it properly: two rounds
buf2 = HypothesisBuffer()
first = make_tokens(["hello", "world"])
buf2.insert(first, offset=0.0)
buf2.flush() # buffer was empty → no commits, buffer = ["hello", "world"]
second = make_tokens(["hello", "world", "test"])
buf2.insert(second, offset=0.0)
committed = buf2.flush()
# LCP of ["hello", "world"] and ["hello", "world", "test"] = ["hello", "world"]
assert len(committed) == 2
assert committed[0].text == "hello"
assert committed[1].text == "world"
class TestFlush:
def test_flush_empty(self):
buf = HypothesisBuffer()
committed = buf.flush()
assert committed == []
def test_flush_lcp_matching(self):
buf = HypothesisBuffer()
# Round 1: establish buffer
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
buf.flush() # buffer = ["hello", "world"], committed = []
# Round 2: same prefix, new suffix
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
committed = buf.flush()
assert [t.text for t in committed] == ["hello", "world"]
def test_flush_no_match(self):
buf = HypothesisBuffer()
# Round 1
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
buf.flush()
# Round 2: completely different
buf.insert(make_tokens(["foo", "bar"]), offset=0.0)
committed = buf.flush()
assert committed == []
def test_flush_partial_match(self):
buf = HypothesisBuffer()
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
buf.flush()
buf.insert(make_tokens(["hello", "earth", "again"]), offset=0.0)
committed = buf.flush()
assert len(committed) == 1
assert committed[0].text == "hello"
def test_flush_updates_last_committed(self):
buf = HypothesisBuffer()
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
buf.flush()
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
buf.flush()
assert buf.last_committed_word == "world"
assert buf.last_committed_time > 0
def test_flush_with_confidence_validation(self):
buf = HypothesisBuffer(confidence_validation=True)
high_conf = [
ASRToken(start=0.0, end=0.5, text="sure", probability=0.99),
ASRToken(start=0.5, end=1.0, text="maybe", probability=0.5),
]
buf.insert(high_conf, offset=0.0)
committed = buf.flush()
# "sure" has p>0.95 → committed immediately
assert len(committed) == 1
assert committed[0].text == "sure"
class TestPopCommitted:
def test_pop_removes_old(self):
buf = HypothesisBuffer()
buf.committed_in_buffer = make_tokens(["a", "b", "c"], start=0.0, step=1.0)
# "a": end=1.0, "b": end=2.0, "c": end=3.0
# pop_committed removes tokens with end <= time
buf.pop_committed(2.0)
# "a" (end=1.0) and "b" (end=2.0) removed, "c" (end=3.0) remains
assert len(buf.committed_in_buffer) == 1
assert buf.committed_in_buffer[0].text == "c"
def test_pop_nothing(self):
buf = HypothesisBuffer()
buf.committed_in_buffer = make_tokens(["a", "b"], start=5.0)
buf.pop_committed(0.0)
assert len(buf.committed_in_buffer) == 2
def test_pop_all(self):
buf = HypothesisBuffer()
buf.committed_in_buffer = make_tokens(["a", "b"], start=0.0, step=0.5)
buf.pop_committed(100.0)
assert len(buf.committed_in_buffer) == 0
class TestStreamingSimulation:
"""Multi-round insert/flush simulating real streaming behavior."""
def test_three_rounds(self):
buf = HypothesisBuffer()
all_committed = []
# Round 1: "this is"
buf.insert(make_tokens(["this", "is"]), offset=0.0)
all_committed.extend(buf.flush())
# Round 2: "this is a test"
buf.insert(make_tokens(["this", "is", "a", "test"]), offset=0.0)
all_committed.extend(buf.flush())
# Round 3: "this is a test today"
buf.insert(make_tokens(["this", "is", "a", "test", "today"]), offset=0.0)
all_committed.extend(buf.flush())
words = [t.text for t in all_committed]
assert "this" in words
assert "is" in words
assert "a" in words
assert "test" in words

183
tests/test_metrics.py Normal file
View File

@@ -0,0 +1,183 @@
"""Tests for whisperlivekit.metrics — WER, timestamp accuracy, normalization."""
import pytest
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy, normalize_text
class TestNormalizeText:
def test_lowercase(self):
assert normalize_text("Hello World") == "hello world"
def test_strip_punctuation(self):
assert normalize_text("Hello, world!") == "hello world"
def test_collapse_whitespace(self):
assert normalize_text(" hello world ") == "hello world"
def test_keep_hyphens(self):
assert normalize_text("real-time") == "real-time"
def test_keep_apostrophes(self):
assert normalize_text("don't") == "don't"
def test_unicode_normalized(self):
# e + combining accent should be same as precomposed
assert normalize_text("caf\u0065\u0301") == normalize_text("caf\u00e9")
def test_empty(self):
assert normalize_text("") == ""
def test_only_punctuation(self):
assert normalize_text("...!?") == ""
class TestComputeWER:
def test_perfect_match(self):
result = compute_wer("hello world", "hello world")
assert result["wer"] == 0.0
assert result["substitutions"] == 0
assert result["insertions"] == 0
assert result["deletions"] == 0
def test_case_insensitive(self):
result = compute_wer("Hello World", "hello world")
assert result["wer"] == 0.0
def test_punctuation_ignored(self):
result = compute_wer("Hello, world!", "hello world")
assert result["wer"] == 0.0
def test_one_substitution(self):
result = compute_wer("hello world", "hello earth")
assert result["wer"] == pytest.approx(0.5)
assert result["substitutions"] == 1
def test_one_insertion(self):
result = compute_wer("hello world", "hello big world")
assert result["wer"] == pytest.approx(0.5)
assert result["insertions"] == 1
def test_one_deletion(self):
result = compute_wer("hello big world", "hello world")
assert result["wer"] == pytest.approx(1 / 3)
assert result["deletions"] == 1
def test_completely_different(self):
result = compute_wer("the cat sat", "a dog ran")
assert result["wer"] == pytest.approx(1.0)
def test_empty_reference(self):
result = compute_wer("", "hello")
assert result["wer"] == 1.0 # 1 insertion / 0 ref → treated as float(m)
assert result["ref_words"] == 0
def test_empty_hypothesis(self):
result = compute_wer("hello world", "")
assert result["wer"] == pytest.approx(1.0)
assert result["deletions"] == 2
def test_both_empty(self):
result = compute_wer("", "")
assert result["wer"] == 0.0
def test_ref_and_hyp_word_counts(self):
result = compute_wer("one two three", "one two three four")
assert result["ref_words"] == 3
assert result["hyp_words"] == 4
class TestComputeTimestampAccuracy:
def test_perfect_match(self):
words = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0},
]
result = compute_timestamp_accuracy(words, words)
assert result["mae_start"] == 0.0
assert result["max_delta_start"] == 0.0
assert result["n_matched"] == 2
def test_constant_offset(self):
ref = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 0.5, "end": 1.0},
]
pred = [
{"word": "hello", "start": 0.1, "end": 0.6},
{"word": "world", "start": 0.6, "end": 1.1},
]
result = compute_timestamp_accuracy(pred, ref)
assert result["mae_start"] == pytest.approx(0.1)
assert result["max_delta_start"] == pytest.approx(0.1)
assert result["n_matched"] == 2
def test_mismatched_word_counts(self):
ref = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "beautiful", "start": 0.5, "end": 1.0},
{"word": "world", "start": 1.0, "end": 1.5},
]
pred = [
{"word": "hello", "start": 0.0, "end": 0.5},
{"word": "world", "start": 1.1, "end": 1.6},
]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 2
assert result["n_ref"] == 3
assert result["n_pred"] == 2
def test_empty_predicted(self):
ref = [{"word": "hello", "start": 0.0, "end": 0.5}]
result = compute_timestamp_accuracy([], ref)
assert result["mae_start"] is None
assert result["n_matched"] == 0
def test_empty_reference(self):
pred = [{"word": "hello", "start": 0.0, "end": 0.5}]
result = compute_timestamp_accuracy(pred, [])
assert result["mae_start"] is None
assert result["n_matched"] == 0
def test_case_insensitive_matching(self):
ref = [{"word": "Hello", "start": 0.0, "end": 0.5}]
pred = [{"word": "hello", "start": 0.1, "end": 0.6}]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 1
assert result["mae_start"] == pytest.approx(0.1)
def test_median_even_count(self):
"""Median with even number of matched words should average the two middle values."""
ref = [
{"word": "a", "start": 0.0, "end": 0.2},
{"word": "b", "start": 0.5, "end": 0.7},
{"word": "c", "start": 1.0, "end": 1.2},
{"word": "d", "start": 1.5, "end": 1.7},
]
pred = [
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
{"word": "b", "start": 0.7, "end": 0.9}, # delta 0.2
{"word": "c", "start": 1.3, "end": 1.5}, # delta 0.3
{"word": "d", "start": 1.9, "end": 2.1}, # delta 0.4
]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 4
# sorted abs deltas: [0.1, 0.2, 0.3, 0.4] -> median = (0.2 + 0.3) / 2 = 0.25
assert result["median_delta_start"] == pytest.approx(0.25)
def test_median_odd_count(self):
"""Median with odd number of matched words takes the middle value."""
ref = [
{"word": "a", "start": 0.0, "end": 0.2},
{"word": "b", "start": 0.5, "end": 0.7},
{"word": "c", "start": 1.0, "end": 1.2},
]
pred = [
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
{"word": "b", "start": 0.8, "end": 1.0}, # delta 0.3
{"word": "c", "start": 1.2, "end": 1.4}, # delta 0.2
]
result = compute_timestamp_accuracy(pred, ref)
assert result["n_matched"] == 3
# sorted abs deltas: [0.1, 0.2, 0.3] -> median = 0.2
assert result["median_delta_start"] == pytest.approx(0.2)

View File

@@ -0,0 +1,99 @@
"""Tests for silence handling — state machine and double-counting regression."""
import pytest
from whisperlivekit.timed_objects import Silence
class TestSilenceStateMachine:
"""Test Silence object state transitions."""
def test_initial_state(self):
s = Silence(start=1.0, is_starting=True)
assert s.is_starting is True
assert s.has_ended is False
assert s.duration is None
assert s.end is None
def test_end_silence(self):
s = Silence(start=1.0, is_starting=True)
s.end = 3.0
s.is_starting = False
s.has_ended = True
s.compute_duration()
assert s.duration == pytest.approx(2.0)
def test_very_short_silence(self):
s = Silence(start=1.0, end=1.01, is_starting=False, has_ended=True)
s.compute_duration()
assert s.duration == pytest.approx(0.01)
def test_zero_duration_silence(self):
s = Silence(start=5.0, end=5.0)
s.compute_duration()
assert s.duration == pytest.approx(0.0)
class TestSilenceDoubleCounting:
"""Regression tests for the silence double-counting bug.
The bug: _begin_silence and _end_silence both pushed self.current_silence
to the queue. Since they were the same Python object, _end_silence's mutation
affected the already-queued start event. The consumer processed both as
ended silences, doubling the duration.
Fix: _begin_silence now pushes a separate Silence object for the start event.
"""
def test_start_and_end_are_separate_objects(self):
"""Simulate the fix: start event and end event must be different objects."""
# Simulate _begin_silence: creates start event as separate object
current_silence = Silence(start=1.0, is_starting=True)
start_event = Silence(start=1.0, is_starting=True) # separate copy
# Simulate _end_silence: mutates current_silence
current_silence.end = 3.0
current_silence.is_starting = False
current_silence.has_ended = True
current_silence.compute_duration()
# start_event should NOT be affected by mutations to current_silence
assert start_event.is_starting is True
assert start_event.has_ended is False
assert start_event.end is None
# current_silence (end event) has the final state
assert current_silence.has_ended is True
assert current_silence.duration == pytest.approx(2.0)
def test_single_object_would_cause_double_counting(self):
"""Demonstrate the bug: if same object is used for both events."""
shared = Silence(start=1.0, is_starting=True)
queue = [shared] # start event queued
# Mutate (simulates _end_silence)
shared.end = 3.0
shared.is_starting = False
shared.has_ended = True
shared.compute_duration()
queue.append(shared) # end event queued
# Both queue items point to the SAME mutated object
assert queue[0] is queue[1] # same reference
assert queue[0].has_ended is True # start event also shows ended!
# This would cause double-counting: both items have has_ended=True
# and duration=2.0, so the consumer adds 2.0 twice = 4.0
class TestConsecutiveSilences:
def test_multiple_silences(self):
"""Multiple silence periods should have independent durations."""
s1 = Silence(start=1.0, end=2.0)
s1.compute_duration()
s2 = Silence(start=5.0, end=8.0)
s2.compute_duration()
assert s1.duration == pytest.approx(1.0)
assert s2.duration == pytest.approx(3.0)
# Total silence should be sum, not accumulated on single object
assert s1.duration + s2.duration == pytest.approx(4.0)

185
tests/test_timed_objects.py Normal file
View File

@@ -0,0 +1,185 @@
"""Tests for whisperlivekit.timed_objects data classes."""
import pytest
from whisperlivekit.timed_objects import (
ASRToken,
FrontData,
Segment,
Silence,
TimedText,
Transcript,
format_time,
)
class TestFormatTime:
def test_zero(self):
assert format_time(0) == "0:00:00"
def test_one_minute(self):
assert format_time(60) == "0:01:00"
def test_one_hour(self):
assert format_time(3600) == "1:00:00"
def test_fractional_truncated(self):
assert format_time(61.9) == "0:01:01"
class TestASRToken:
def test_with_offset(self):
t = ASRToken(start=1.0, end=2.0, text="hello")
shifted = t.with_offset(0.5)
assert shifted.start == pytest.approx(1.5)
assert shifted.end == pytest.approx(2.5)
assert shifted.text == "hello"
def test_with_offset_preserves_fields(self):
t = ASRToken(start=0.0, end=1.0, text="hi", speaker=2, probability=0.95)
shifted = t.with_offset(1.0)
assert shifted.speaker == 2
assert shifted.probability == 0.95
def test_is_silence_false(self):
t = ASRToken(start=0.0, end=1.0, text="hello")
assert t.is_silence() is False
def test_bool_truthy(self):
t = ASRToken(start=0.0, end=1.0, text="hello")
assert bool(t) is True
def test_bool_falsy(self):
t = ASRToken(start=0.0, end=1.0, text="")
assert bool(t) is False
class TestTimedText:
def test_has_punctuation_period(self):
t = TimedText(text="hello.")
assert t.has_punctuation() is True
def test_has_punctuation_exclamation(self):
t = TimedText(text="wow!")
assert t.has_punctuation() is True
def test_has_punctuation_question(self):
t = TimedText(text="really?")
assert t.has_punctuation() is True
def test_has_punctuation_cjk(self):
t = TimedText(text="hello。")
assert t.has_punctuation() is True
def test_no_punctuation(self):
t = TimedText(text="hello world")
assert t.has_punctuation() is False
def test_duration(self):
t = TimedText(start=1.0, end=3.5)
assert t.duration() == pytest.approx(2.5)
def test_contains_timespan(self):
outer = TimedText(start=0.0, end=5.0)
inner = TimedText(start=1.0, end=3.0)
assert outer.contains_timespan(inner) is True
assert inner.contains_timespan(outer) is False
class TestSilence:
def test_compute_duration(self):
s = Silence(start=1.0, end=3.5)
d = s.compute_duration()
assert d == pytest.approx(2.5)
assert s.duration == pytest.approx(2.5)
def test_compute_duration_none_start(self):
s = Silence(start=None, end=3.5)
d = s.compute_duration()
assert d is None
def test_compute_duration_none_end(self):
s = Silence(start=1.0, end=None)
d = s.compute_duration()
assert d is None
def test_is_silence_true(self):
s = Silence()
assert s.is_silence() is True
class TestTranscript:
def test_from_tokens(self, sample_tokens):
t = Transcript.from_tokens(sample_tokens, sep="")
assert t.text == "Hello world test."
assert t.start == pytest.approx(0.0)
assert t.end == pytest.approx(1.5)
def test_from_tokens_with_sep(self, sample_tokens):
t = Transcript.from_tokens(sample_tokens, sep="|")
assert t.text == "Hello| world| test."
def test_from_empty_tokens(self):
t = Transcript.from_tokens([])
assert t.text == ""
assert t.start is None
assert t.end is None
def test_from_tokens_with_offset(self, sample_tokens):
t = Transcript.from_tokens(sample_tokens, offset=10.0)
assert t.start == pytest.approx(10.0)
assert t.end == pytest.approx(11.5)
class TestSegment:
def test_from_tokens(self, sample_tokens):
seg = Segment.from_tokens(sample_tokens)
assert seg is not None
assert seg.text == "Hello world test."
assert seg.start == pytest.approx(0.0)
assert seg.end == pytest.approx(1.5)
assert seg.speaker == -1
def test_from_silence_tokens(self):
silences = [
Silence(start=1.0, end=2.0),
Silence(start=2.0, end=3.0),
]
seg = Segment.from_tokens(silences, is_silence=True)
assert seg is not None
assert seg.speaker == -2
assert seg.is_silence() is True
assert seg.text is None
def test_from_empty_tokens(self):
seg = Segment.from_tokens([])
assert seg is None
def test_to_dict(self, sample_tokens):
seg = Segment.from_tokens(sample_tokens)
d = seg.to_dict()
assert "text" in d
assert "speaker" in d
assert "start" in d
assert "end" in d
class TestFrontData:
def test_to_dict_empty(self):
fd = FrontData()
d = fd.to_dict()
assert d["lines"] == []
assert d["buffer_transcription"] == ""
assert "error" not in d
def test_to_dict_with_error(self):
fd = FrontData(error="something broke")
d = fd.to_dict()
assert d["error"] == "something broke"
def test_to_dict_with_lines(self, sample_tokens):
seg = Segment.from_tokens(sample_tokens)
fd = FrontData(lines=[seg])
d = fd.to_dict()
assert len(d["lines"]) == 1
assert d["lines"][0]["text"] == "Hello world test."

6575
uv.lock generated Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,7 @@
from .audio_processor import AudioProcessor from .audio_processor import AudioProcessor
from .core import TranscriptionEngine from .core import TranscriptionEngine
from .parse_args import parse_args from .parse_args import parse_args
from .web.web_interface import get_web_interface_html, get_inline_ui_html from .web.web_interface import get_inline_ui_html, get_web_interface_html
__all__ = [ __all__ = [
"TranscriptionEngine", "TranscriptionEngine",

View File

@@ -1,99 +1,103 @@
import asyncio import asyncio
import numpy as np
from time import time, sleep
import math
import logging import logging
import traceback import traceback
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker from time import time
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from typing import Any, AsyncGenerator, List, Optional, Union
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output import numpy as np
from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory,
online_translation_factory)
from whisperlivekit.metrics_collector import SessionMetrics
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
Segment, Silence, State, Transcript)
from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker SENTINEL = object() # unique sentinel object for end of stream marker
MIN_DURATION_REAL_SILENCE = 5
def cut_at(cumulative_pcm, cut_sec): async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
cumulative_len = 0 items: List[Any] = []
cut_sample = int(cut_sec * 16000)
for ind, pcm_array in enumerate(cumulative_pcm):
if (cumulative_len + len(pcm_array)) >= cut_sample:
cut_chunk = cut_sample - cumulative_len
before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]])
after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:]
return before, after
cumulative_len += len(pcm_array)
return np.concatenate(cumulative_pcm), []
async def get_all_from_queue(queue): first_item = await queue.get()
items = [] queue.task_done()
try: if first_item is SENTINEL:
while True: return first_item
item = queue.get_nowait() if isinstance(first_item, Silence):
items.append(item) return first_item
except asyncio.QueueEmpty: items.append(first_item)
pass
return items while True:
if not queue._queue:
break
next_item = queue._queue[0]
if next_item is SENTINEL:
break
if isinstance(next_item, Silence):
break
items.append(await queue.get())
queue.task_done()
if isinstance(items[0], np.ndarray):
return np.concatenate(items)
else: #translation
return items
class AudioProcessor: class AudioProcessor:
""" """
Processes audio streams for transcription and diarization. Processes audio streams for transcription and diarization.
Handles audio processing, state management, and result formatting. Handles audio processing, state management, and result formatting.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state.""" """Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine): if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
models = kwargs['transcription_engine'] models = kwargs['transcription_engine']
else: else:
models = TranscriptionEngine(**kwargs) models = TranscriptionEngine(**kwargs)
# Audio processing settings # Audio processing settings
self.args = models.args self.args = models.args
self.sample_rate = 16000 self.sample_rate = 16000
self.channels = 1 self.channels = 1
self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size) chunk_seconds = self.args.vac_chunk_size if self.args.vac else self.args.min_chunk_size
self.samples_per_sec = int(self.sample_rate * chunk_seconds)
self.bytes_per_sample = 2 self.bytes_per_sample = 2
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
self.is_pcm_input = self.args.pcm_input self.is_pcm_input = self.args.pcm_input
# State management # State management
self.is_stopping = False self.is_stopping: bool = False
self.silence = False self.current_silence: Optional[Silence] = None
self.silence_duration = 0.0 self.state: State = State()
self.state = State() self.lock: asyncio.Lock = asyncio.Lock()
self.lock = asyncio.Lock() self.sep: str = " " # Default separator
self.sep = " " # Default separator self.last_response_content: FrontData = FrontData()
self.last_response_content = FrontData()
self.last_detected_speaker = None
self.speaker_languages = {}
self.diarization_before_transcription = False
self.segments = [] self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
self.beg_loop: Optional[float] = None
if self.diarization_before_transcription:
self.cumulative_pcm = []
self.last_start = 0.0
self.last_end = 0.0
# Models and processing # Models and processing
self.asr = models.asr self.asr: Any = models.asr
self.vac_model = models.vac_model self.vac: Optional[FixedVADIterator] = None
if self.args.vac: if self.args.vac:
self.vac = FixedVADIterator(models.vac_model) if models.vac_session is not None:
else: vac_model = OnnxWrapper(session=models.vac_session)
self.vac = None self.vac = FixedVADIterator(vac_model)
else:
self.ffmpeg_manager = None self.vac = FixedVADIterator(load_jit_vad())
self.ffmpeg_reader_task = None self.ffmpeg_manager: Optional[FFmpegManager] = None
self._ffmpeg_error = None self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error: Optional[str] = None
if not self.is_pcm_input: if not self.is_pcm_input:
self.ffmpeg_manager = FFmpegManager( self.ffmpeg_manager = FFmpegManager(
@@ -104,63 +108,122 @@ class AudioProcessor:
logger.error(f"FFmpeg error: {error_type}") logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
self.translation_queue = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer = bytearray()
self.transcription_task = None self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_task = None self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_task = None self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
self.watchdog_task = None self.pcm_buffer: bytearray = bytearray()
self.all_tasks_for_cleanup = [] self.total_pcm_samples: int = 0
self.transcription_task: Optional[asyncio.Task] = None
self.transcription = None self.diarization_task: Optional[asyncio.Task] = None
self.translation = None self.translation_task: Optional[asyncio.Task] = None
self.diarization = None self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.metrics: SessionMetrics = SessionMetrics()
self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None
self.diarization: Optional[Any] = None
if self.args.transcription: if self.args.transcription:
self.transcription = online_factory(self.args, models.asr) self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep self.sep = self.transcription.asr.sep
if self.args.diarization: if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model) self.diarization = online_diarization_factory(self.args, models.diarization_model)
if models.translation_model: if models.translation_model:
self.translation = online_translation_factory(self.args, models.translation_model) self.translation = online_translation_factory(self.args, models.translation_model)
def convert_pcm_to_float(self, pcm_buffer): async def _push_silence_event(self) -> None:
if self.transcription_queue:
await self.transcription_queue.put(self.current_silence)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(self.current_silence)
if self.translation_queue:
await self.translation_queue.put(self.current_silence)
async def _begin_silence(self, at_sample: Optional[int] = None) -> None:
if self.current_silence:
return
# Use audio stream time (sample-precise) for accurate silence duration
if at_sample is not None:
audio_t = at_sample / self.sample_rate
else:
audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
self.current_silence = Silence(
is_starting=True, start=audio_t
)
# Push a separate start-only event so _end_silence won't mutate it
start_event = Silence(is_starting=True, start=audio_t)
if self.transcription_queue:
await self.transcription_queue.put(start_event)
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(start_event)
if self.translation_queue:
await self.translation_queue.put(start_event)
async def _end_silence(self, at_sample: Optional[int] = None) -> None:
if not self.current_silence:
return
if at_sample is not None:
audio_t = at_sample / self.sample_rate
else:
audio_t = self.total_pcm_samples / self.sample_rate if self.sample_rate else 0.0
self.current_silence.end = audio_t
self.current_silence.is_starting = False
self.current_silence.has_ended = True
self.current_silence.compute_duration()
self.metrics.n_silence_events += 1
if self.current_silence.duration is not None:
self.metrics.total_silence_duration_s += self.current_silence.duration
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
self.state.new_tokens.append(self.current_silence)
# Push the completed silence as the end event (separate from the start event)
await self._push_silence_event()
self.current_silence = None
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
if pcm_chunk is None or pcm_chunk.size == 0:
return
if self.transcription_queue:
await self.transcription_queue.put(pcm_chunk.copy())
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(pcm_chunk.copy())
def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
if silence_sample is None:
return None
relative_index = int(silence_sample - chunk_sample_start)
if relative_index <= 0:
return None
split_index = min(relative_index, len(pcm_array))
if split_index <= 0:
return None
return pcm_array[:split_index]
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def add_dummy_token(self): async def get_current_state(self) -> State:
"""Placeholder token when no transcription is available."""
async with self.lock:
current_time = time() - self.state.beg_loop
self.state.tokens.append(ASRToken(
start=current_time, end=current_time + 1,
text=".", speaker=-1, is_dummy=True
))
async def get_current_state(self):
"""Get current state.""" """Get current state."""
async with self.lock: async with self.lock:
current_time = time() current_time = time()
remaining_transcription = 0 remaining_transcription = 0
if self.state.end_buffer > 0: if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.state.beg_loop - self.state.end_buffer, 1)) remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0 remaining_diarization = 0
if self.state.tokens: if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1)) remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization self.state.remaining_time_diarization = remaining_diarization
return self.state return self.state
async def ffmpeg_stdout_reader(self): async def ffmpeg_stdout_reader(self) -> None:
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline.""" """Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
beg = time() beg = time()
while True: while True:
@@ -203,50 +266,94 @@ class AudioProcessor:
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.") logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
if not self.diarization_before_transcription and self.transcription_queue: if self.transcription_queue:
await self.transcription_queue.put(SENTINEL) await self.transcription_queue.put(SENTINEL)
if self.diarization: if self.diarization:
await self.diarization_queue.put(SENTINEL) await self.diarization_queue.put(SENTINEL)
if self.translation: if self.translation:
await self.translation_queue.put(SENTINEL) await self.translation_queue.put(SENTINEL)
async def transcription_processor(self): async def _finish_transcription(self) -> None:
"""Call finish() on the online processor to flush remaining tokens."""
if not self.transcription:
return
try:
if hasattr(self.transcription, 'finish'):
final_tokens, end_time = await asyncio.to_thread(self.transcription.finish)
else:
# SimulStreamingOnlineProcessor uses start_silence() → process_iter(is_last=True)
final_tokens, end_time = await asyncio.to_thread(self.transcription.start_silence)
final_tokens = final_tokens or []
if final_tokens:
logger.info(f"Finish flushed {len(final_tokens)} tokens")
_buffer_transcript = self.transcription.get_buffer()
async with self.lock:
self.state.tokens.extend(final_tokens)
self.state.buffer_transcription = _buffer_transcript
self.state.end_buffer = max(self.state.end_buffer, end_time)
self.state.new_tokens.extend(final_tokens)
self.state.new_tokens_buffer = _buffer_transcript
if self.translation_queue:
for token in final_tokens:
await self.translation_queue.put(token)
except Exception as e:
logger.warning(f"Error finishing transcription: {e}")
logger.debug(f"Traceback: {traceback.format_exc()}")
async def transcription_processor(self) -> None:
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0 cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
item = await self.transcription_queue.get() # item = await self.transcription_queue.get()
item = await get_all_from_queue(self.transcription_queue)
if item is SENTINEL: if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.") logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done() await self._finish_transcription()
break break
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer) transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |" asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
if type(item) is Silence: stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
asr_processing_logs += f" + Silence of = {item.duration:.2f}s" new_tokens = []
current_audio_processed_upto = self.state.end_buffer
if isinstance(item, Silence):
if item.is_starting:
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
self.transcription.start_silence
)
asr_processing_logs += f" + Silence starting"
if item.has_ended:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
cumulative_pcm_duration_stream_time += item.duration
current_audio_processed_upto = cumulative_pcm_duration_stream_time
self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
if self.state.tokens: if self.state.tokens:
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |" asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
logger.info(asr_processing_logs) logger.info(asr_processing_logs)
cumulative_pcm_duration_stream_time += item.duration new_tokens = new_tokens or []
self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0) current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
continue
elif isinstance(item, ChangeSpeaker): elif isinstance(item, ChangeSpeaker):
self.transcription.new_speaker(item) self.transcription.new_speaker(item)
continue
elif isinstance(item, np.ndarray): elif isinstance(item, np.ndarray):
pcm_array = item pcm_array = item
logger.info(asr_processing_logs)
logger.info(asr_processing_logs) cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
duration_this_chunk = len(pcm_array) / self.sample_rate self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
cumulative_pcm_duration_stream_time += duration_this_chunk _t0 = time()
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
_dur = time() - _t0
self.metrics.transcription_durations.append(_dur)
self.metrics.n_transcription_calls += 1
new_tokens = new_tokens or []
self.metrics.n_tokens_produced += len(new_tokens)
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
_buffer_transcript = self.transcription.get_buffer() _buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text buffer_text = _buffer_transcript.text
@@ -259,29 +366,28 @@ class AudioProcessor:
if new_tokens: if new_tokens:
candidate_end_times.append(new_tokens[-1].end) candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None: if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end) candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto) candidate_end_times.append(current_audio_processed_upto)
async with self.lock: async with self.lock:
self.state.tokens.extend(new_tokens) self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript self.state.buffer_transcription = _buffer_transcript
self.state.end_buffer = max(candidate_end_times) self.state.end_buffer = max(candidate_end_times)
self.state.new_tokens.extend(new_tokens)
self.state.new_tokens_buffer = _buffer_transcript
if self.translation_queue: if self.translation_queue:
for token in new_tokens: for token in new_tokens:
await self.translation_queue.put(token) await self.translation_queue.put(token)
self.transcription_queue.task_done()
except Exception as e: except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done() self.transcription_queue.task_done()
if self.is_stopping: if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.") logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue: if self.diarization_queue:
@@ -292,124 +398,60 @@ class AudioProcessor:
logger.info("Transcription processor task finished.") logger.info("Transcription processor task finished.")
async def diarization_processor(self, diarization_obj): async def diarization_processor(self) -> None:
"""Process audio chunks for speaker diarization."""
if self.diarization_before_transcription:
self.current_speaker = 0
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0))
while True: while True:
try: try:
item = await self.diarization_queue.get() item = await get_all_from_queue(self.diarization_queue)
if item is SENTINEL: if item is SENTINEL:
logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done()
break break
elif type(item) is Silence: elif type(item) is Silence:
diarization_obj.insert_silence(item.duration) if item.has_ended:
self.diarization.insert_silence(item.duration)
continue continue
elif isinstance(item, np.ndarray): self.diarization.insert_audio_chunk(item)
pcm_array = item diarization_segments = await self.diarization.diarize()
else: diar_end = 0.0
raise Exception('item should be pcm_array') if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
# Process diarization self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
await diarization_obj.diarize(pcm_array)
if self.diarization_before_transcription:
segments = diarization_obj.get_segments()
self.cumulative_pcm.append(pcm_array)
if segments:
last_segment = segments[-1]
if last_segment.speaker != self.current_speaker:
cut_sec = last_segment.start - self.last_end
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
await self.transcription_queue.put(to_transcript)
self.current_speaker = last_segment.speaker
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=last_segment.start))
cut_sec = last_segment.end - last_segment.start
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
await self.transcription_queue.put(to_transcript)
self.last_start = last_segment.start
self.last_end = last_segment.end
else:
cut_sec = last_segment.end - self.last_end
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
await self.transcription_queue.put(to_transcript)
self.last_end = last_segment.end
elif not self.diarization_before_transcription:
async with self.lock:
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
self.state.tokens,
use_punctuation_split=self.args.punctuation_split
)
if len(self.state.tokens) > 0:
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
self.diarization_queue.task_done()
except Exception as e: except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL:
self.diarization_queue.task_done()
logger.info("Diarization processor task finished.") logger.info("Diarization processor task finished.")
async def translation_processor(self): async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens. # the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation # And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex. # in the future we want to have different languages for each speaker etc, so it will be more complex.
while True: while True:
try: try:
item = await self.translation_queue.get() #block until at least 1 token item = await get_all_from_queue(self.translation_queue)
if item is SENTINEL: if item is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.") logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done()
break break
elif type(item) is Silence: elif type(item) is Silence:
self.translation.insert_silence(item.duration) if item.is_starting:
continue new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
if item.has_ended:
# get all the available tokens for translation. The more words, the more precise self.translation.insert_silence(item.duration)
tokens_to_process = [item]
additional_tokens = await get_all_from_queue(self.translation_queue)
sentinel_found = False
for additional_token in additional_tokens:
if additional_token is SENTINEL:
sentinel_found = True
break
elif type(additional_token) is Silence:
self.translation.insert_silence(additional_token.duration)
continue continue
else: elif isinstance(item, ChangeSpeaker):
tokens_to_process.append(additional_token) new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
if tokens_to_process: pass
self.translation.insert_tokens(tokens_to_process) else:
translation_validated_segments, translation_buffer = await asyncio.to_thread(self.translation.process) self.translation.insert_tokens(item)
async with self.lock: new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
self.state.translation_validated_segments = translation_validated_segments async with self.lock:
self.state.translation_buffer = translation_buffer self.state.new_translation.append(new_translation)
self.translation_queue.task_done() self.state.new_translation_buffer = new_translation_buffer
for _ in additional_tokens:
self.translation_queue.task_done()
if sentinel_found:
logger.debug("Translation processor received sentinel in batch. Finishing.")
break
except Exception as e: except Exception as e:
logger.warning(f"Exception in translation_processor: {e}") logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'token' in locals() and item is not SENTINEL:
self.translation_queue.task_done()
if 'additional_tokens' in locals():
for _ in additional_tokens:
self.translation_queue.task_done()
logger.info("Translation processor task finished.") logger.info("Translation processor task finished.")
async def results_formatter(self): async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
"""Format processing results for output.""" """Format processing results for output."""
while True: while True:
try: try:
@@ -419,72 +461,57 @@ class AudioProcessor:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
state = await self.get_current_state() self.tokens_alignment.update()
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
lines, undiarized_text = format_output( diarization=self.args.diarization,
state, translation=bool(self.translation),
self.silence, current_silence=self.current_silence
args = self.args,
sep=self.sep
) )
if lines and lines[-1].speaker == -2: state = await self.get_current_state()
buffer_transcription = Transcript()
else:
buffer_transcription = state.buffer_transcription
buffer_diarization = '' buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
if undiarized_text:
buffer_diarization = self.sep.join(undiarized_text)
async with self.lock:
self.state.end_attributed_speaker = state.end_attributed_speaker
response_status = "active_transcription" response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization: if not lines and not buffer_transcription_text and not buffer_diarization_text:
response_status = "no_audio_detected" response_status = "no_audio_detected"
lines = []
elif not lines:
lines = [Line(
speaker=1,
start=state.end_buffer,
end=state.end_buffer
)]
response = FrontData( response = FrontData(
status=response_status, status=response_status,
lines=lines, lines=lines,
buffer_transcription=buffer_transcription.text.strip(), buffer_transcription=buffer_transcription_text,
buffer_diarization=buffer_diarization, buffer_diarization=buffer_diarization_text,
buffer_translation=buffer_translation_text,
remaining_time_transcription=state.remaining_time_transcription, remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
) )
should_push = (response != self.last_response_content) should_push = (response != self.last_response_content)
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"): if should_push:
self.metrics.n_responses_sent += 1
yield response yield response
self.last_response_content = response self.last_response_content = response
if self.is_stopping and self.transcription_task and self.transcription_task.done() and self.diarization_task and self.diarization_task.done(): if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.") logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return return
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
except Exception as e: except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}") logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
async def create_tasks(self): async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks.""" """Create and start processing tasks."""
self.all_tasks_for_cleanup = [] self.all_tasks_for_cleanup = []
processing_tasks_for_watchdog = [] processing_tasks_for_watchdog: List[asyncio.Task] = []
# If using FFmpeg (non-PCM input), start it and spawn stdout reader # If using FFmpeg (non-PCM input), start it and spawn stdout reader
if not self.is_pcm_input: if not self.is_pcm_input:
success = await self.ffmpeg_manager.start() success = await self.ffmpeg_manager.start()
if not success: if not success:
logger.error("Failed to start FFmpeg manager") logger.error("Failed to start FFmpeg manager")
async def error_generator(): async def error_generator() -> AsyncGenerator[FrontData, None]:
yield FrontData( yield FrontData(
status="error", status="error",
error="FFmpeg failed to start. Please check that FFmpeg is installed." error="FFmpeg failed to start. Please check that FFmpeg is installed."
@@ -498,30 +525,35 @@ class AudioProcessor:
self.transcription_task = asyncio.create_task(self.transcription_processor()) self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task) self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task) processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization: if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization)) self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task) self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation: if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor()) self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task) self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task) processing_tasks_for_watchdog.append(self.translation_task)
# Monitor overall system health # Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog)) self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task) self.all_tasks_for_cleanup.append(self.watchdog_task)
return self.results_formatter() return self.results_formatter()
async def watchdog(self, tasks_to_monitor): async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
"""Monitors the health of critical processing tasks.""" """Monitors the health of critical processing tasks."""
tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
while True: while True:
try: try:
if not tasks_remaining:
logger.info("Watchdog task finishing: all monitored tasks completed.")
return
await asyncio.sleep(10) await asyncio.sleep(10)
for i, task in enumerate(tasks_to_monitor): for i, task in enumerate(list(tasks_remaining)):
if task.done(): if task.done():
exc = task.exception() exc = task.exception()
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}" task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
@@ -529,21 +561,22 @@ class AudioProcessor:
logger.error(f"{task_name} unexpectedly completed with exception: {exc}") logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
else: else:
logger.info(f"{task_name} completed normally.") logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Watchdog task cancelled.") logger.info("Watchdog task cancelled.")
break break
except Exception as e: except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True) logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self): async def cleanup(self) -> None:
"""Clean up resources when processing is complete.""" """Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.") logger.info("Starting cleanup of AudioProcessor resources.")
self.is_stopping = True self.is_stopping = True
for task in self.all_tasks_for_cleanup: for task in self.all_tasks_for_cleanup:
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t] created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks: if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True) await asyncio.gather(*created_tasks, return_exceptions=True)
@@ -557,19 +590,40 @@ class AudioProcessor:
logger.warning(f"Error stopping FFmpeg manager: {e}") logger.warning(f"Error stopping FFmpeg manager: {e}")
if self.diarization: if self.diarization:
self.diarization.close() self.diarization.close()
# Finalize session metrics
self.metrics.total_audio_duration_s = self.total_pcm_samples / self.sample_rate
self.metrics.log_summary()
logger.info("AudioProcessor cleanup complete.") logger.info("AudioProcessor cleanup complete.")
def _processing_tasks_done(self) -> bool:
"""Return True when all active processing tasks have completed."""
tasks_to_check = [
self.transcription_task,
self.diarization_task,
self.translation_task,
self.ffmpeg_reader_task,
]
return all(task.done() for task in tasks_to_check if task)
async def process_audio(self, message):
async def process_audio(self, message: Optional[bytes]) -> None:
"""Process incoming audio data.""" """Process incoming audio data."""
if not self.state.beg_loop: if not self.beg_loop:
self.state.beg_loop = time() self.beg_loop = time()
self.metrics.session_start = self.beg_loop
self.current_silence = Silence(start=0.0, is_starting=True)
self.tokens_alignment.beg_loop = self.beg_loop
if not message: if not message:
logger.info("Empty audio message received, initiating stop sequence.") logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True self.is_stopping = True
# Flush any remaining PCM data before signaling end-of-stream
if self.is_pcm_input and self.pcm_buffer:
await self._flush_remaining_pcm()
if self.transcription_queue: if self.transcription_queue:
await self.transcription_queue.put(SENTINEL) await self.transcription_queue.put(SENTINEL)
@@ -582,6 +636,8 @@ class AudioProcessor:
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.") logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
return return
self.metrics.n_chunks_received += 1
if self.is_pcm_input: if self.is_pcm_input:
self.pcm_buffer.extend(message) self.pcm_buffer.extend(message)
await self.handle_pcm_data() await self.handle_pcm_data()
@@ -597,7 +653,12 @@ class AudioProcessor:
else: else:
logger.warning("Failed to write audio data to FFmpeg") logger.warning("Failed to write audio data to FFmpeg")
async def handle_pcm_data(self): async def handle_pcm_data(self) -> None:
# Without VAC, there's no speech detector to end the initial silence.
# Clear it on the first audio chunk so audio actually gets enqueued.
if not self.args.vac and self.current_silence:
await self._end_silence()
# Process when enough data # Process when enough data
if len(self.pcm_buffer) < self.bytes_per_sec: if len(self.pcm_buffer) < self.bytes_per_sec:
return return
@@ -610,46 +671,54 @@ class AudioProcessor:
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec) chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0: if aligned_chunk_size == 0:
return return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size]) pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:] self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
res = None num_samples = len(pcm_array)
end_of_audio = False chunk_sample_start = self.total_pcm_samples
silence_buffer = None chunk_sample_end = chunk_sample_start + num_samples
res = None
if self.args.vac: if self.args.vac:
res = self.vac(pcm_array) res = self.vac(pcm_array)
if res is not None: if res is not None:
if res.get("end", 0) > res.get("start", 0): if "start" in res and self.current_silence:
end_of_audio = True await self._end_silence(at_sample=res.get("start"))
elif self.silence: #end of silence
self.silence = False
silence_buffer = Silence(duration=time() - self.start_silence)
if silence_buffer: if "end" in res and not self.current_silence:
if not self.diarization_before_transcription and self.transcription_queue: pre_silence_chunk = self._slice_before_silence(
await self.transcription_queue.put(silence_buffer) pcm_array, chunk_sample_start, res.get("end")
if self.args.diarization and self.diarization_queue: )
await self.diarization_queue.put(silence_buffer) if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
if self.translation_queue: await self._enqueue_active_audio(pre_silence_chunk)
await self.translation_queue.put(silence_buffer) await self._begin_silence(at_sample=res.get("end"))
if not self.silence: if not self.current_silence:
if not self.diarization_before_transcription and self.transcription_queue: await self._enqueue_active_audio(pcm_array)
await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue: self.total_pcm_samples = chunk_sample_end
await self.diarization_queue.put(pcm_array.copy())
self.silence_duration = 0.0
if end_of_audio:
self.silence = True
self.start_silence = time()
if not self.args.transcription and not self.args.diarization: if not self.args.transcription and not self.args.diarization:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
async def _flush_remaining_pcm(self) -> None:
"""Flush whatever PCM data remains in the buffer, regardless of size threshold."""
if not self.pcm_buffer:
return
aligned_size = (len(self.pcm_buffer) // self.bytes_per_sample) * self.bytes_per_sample
if aligned_size == 0:
return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_size])
self.pcm_buffer = self.pcm_buffer[aligned_size:]
# End any active silence so the audio gets enqueued
if self.current_silence:
await self._end_silence(at_sample=self.total_pcm_samples)
await self._enqueue_active_audio(pcm_array)
self.total_pcm_samples += len(pcm_array)
logger.info(f"Flushed remaining PCM buffer: {len(pcm_array)} samples ({len(pcm_array)/self.sample_rate:.2f}s)")

View File

@@ -0,0 +1,47 @@
import importlib.util
import logging
import platform
logger = logging.getLogger(__name__)
def module_available(module_name):
"""Return True if the given module can be imported."""
return importlib.util.find_spec(module_name) is not None
def mlx_backend_available(warn_on_missing = False):
is_macos = platform.system() == "Darwin"
is_arm = platform.machine() == "arm64"
available = (
is_macos
and is_arm
and module_available("mlx_whisper")
)
if not available and warn_on_missing and is_macos and is_arm:
logger.warning(
"=" * 50
+ "\nMLX Whisper not found but you are on Apple Silicon. "
"Consider installing mlx-whisper for better performance: "
"`pip install mlx-whisper`\n"
+ "=" * 50
)
return available
def voxtral_hf_backend_available():
"""Return True if HF Transformers Voxtral backend is available."""
return module_available("transformers")
def faster_backend_available(warn_on_missing = False):
available = module_available("faster_whisper")
if not available and warn_on_missing and platform.system() != "Darwin":
logger.warning(
"=" * 50
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
"for better performance: `pip install faster-whisper`\n"
+ "=" * 50
)
return available

View File

@@ -1,25 +1,26 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio import asyncio
import logging import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
get_inline_ui_html, parse_args)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
args = parse_args() config = parse_args()
transcription_engine = None transcription_engine = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global transcription_engine global transcription_engine
transcription_engine = TranscriptionEngine( transcription_engine = TranscriptionEngine(config=config)
**vars(args),
)
yield yield
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
@@ -60,7 +61,7 @@ async def websocket_endpoint(websocket: WebSocket):
logger.info("WebSocket connection opened.") logger.info("WebSocket connection opened.")
try: try:
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)}) await websocket.send_json({"type": "config", "useAudioWorklet": bool(config.pcm_input)})
except Exception as e: except Exception as e:
logger.warning(f"Failed to send config to client: {e}") logger.warning(f"Failed to send config to client: {e}")
@@ -100,26 +101,26 @@ def main():
uvicorn_kwargs = { uvicorn_kwargs = {
"app": "whisperlivekit.basic_server:app", "app": "whisperlivekit.basic_server:app",
"host":args.host, "host": config.host,
"port":args.port, "port": config.port,
"reload": False, "reload": False,
"log_level": "info", "log_level": "info",
"lifespan": "on", "lifespan": "on",
} }
ssl_kwargs = {} ssl_kwargs = {}
if args.ssl_certfile or args.ssl_keyfile: if config.ssl_certfile or config.ssl_keyfile:
if not (args.ssl_certfile and args.ssl_keyfile): if not (config.ssl_certfile and config.ssl_keyfile):
raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.") raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
ssl_kwargs = { ssl_kwargs = {
"ssl_certfile": args.ssl_certfile, "ssl_certfile": config.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile "ssl_keyfile": config.ssl_keyfile,
} }
if ssl_kwargs: if ssl_kwargs:
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs} uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
if args.forwarded_allow_ips: if config.forwarded_allow_ips:
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips } uvicorn_kwargs = {**uvicorn_kwargs, "forwarded_allow_ips": config.forwarded_allow_ips}
uvicorn.run(**uvicorn_kwargs) uvicorn.run(**uvicorn_kwargs)

102
whisperlivekit/config.py Normal file
View File

@@ -0,0 +1,102 @@
"""Typed configuration for the WhisperLiveKit pipeline."""
import logging
from dataclasses import dataclass, field, fields
from typing import Optional
logger = logging.getLogger(__name__)
@dataclass
class WhisperLiveKitConfig:
"""Single source of truth for all WhisperLiveKit configuration.
Replaces the previous dict-based parameter system in TranscriptionEngine.
All fields have defaults matching the prior behaviour.
"""
# Server / global
host: str = "localhost"
port: int = 8000
diarization: bool = False
punctuation_split: bool = False
target_language: str = ""
vac: bool = True
vac_chunk_size: float = 0.04
log_level: str = "DEBUG"
ssl_certfile: Optional[str] = None
ssl_keyfile: Optional[str] = None
forwarded_allow_ips: Optional[str] = None
transcription: bool = True
vad: bool = True
pcm_input: bool = False
disable_punctuation_split: bool = False
diarization_backend: str = "sortformer"
backend_policy: str = "simulstreaming"
backend: str = "auto"
# Transcription common
warmup_file: Optional[str] = None
min_chunk_size: float = 0.1
model_size: str = "base"
model_cache_dir: Optional[str] = None
model_dir: Optional[str] = None
model_path: Optional[str] = None
lora_path: Optional[str] = None
lan: str = "auto"
direct_english_translation: bool = False
# LocalAgreement-specific
buffer_trimming: str = "segment"
confidence_validation: bool = False
buffer_trimming_sec: float = 15.0
# SimulStreaming-specific
disable_fast_encoder: bool = False
custom_alignment_heads: Optional[str] = None
frame_threshold: int = 25
beams: int = 1
decoder_type: Optional[str] = None
audio_max_len: float = 20.0
audio_min_len: float = 0.0
cif_ckpt_path: Optional[str] = None
never_fire: bool = False
init_prompt: Optional[str] = None
static_init_prompt: Optional[str] = None
max_context_tokens: Optional[int] = None
# Diarization (diart)
segmentation_model: str = "pyannote/segmentation-3.0"
embedding_model: str = "pyannote/embedding"
# Translation
nllb_backend: str = "transformers"
nllb_size: str = "600M"
def __post_init__(self):
# .en model suffix forces English
if self.model_size and self.model_size.endswith(".en"):
self.lan = "en"
# Normalize backend_policy aliases
if self.backend_policy == "1":
self.backend_policy = "simulstreaming"
elif self.backend_policy == "2":
self.backend_policy = "localagreement"
# ------------------------------------------------------------------
# Factory helpers
# ------------------------------------------------------------------
@classmethod
def from_namespace(cls, ns) -> "WhisperLiveKitConfig":
"""Create config from an argparse Namespace, ignoring unknown keys."""
known = {f.name for f in fields(cls)}
return cls(**{k: v for k, v in vars(ns).items() if k in known})
@classmethod
def from_kwargs(cls, **kwargs) -> "WhisperLiveKitConfig":
"""Create config from keyword arguments; warns on unknown keys."""
known = {f.name for f in fields(cls)}
unknown = set(kwargs.keys()) - known
if unknown:
logger.warning("Unknown config keys ignored: %s", unknown)
return cls(**{k: v for k, v in kwargs.items() if k in known})

View File

@@ -1,176 +1,201 @@
try: import logging
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
except ImportError:
from .whisper_streaming_custom.whisper_online import backend_factory
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
from argparse import Namespace
import sys import sys
import threading
from argparse import Namespace
from dataclasses import asdict
def update_with_kwargs(_dict, kwargs): from whisperlivekit.config import WhisperLiveKitConfig
_dict.update({ from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
k: v for k, v in kwargs.items() if k in _dict from whisperlivekit.local_agreement.whisper_online import backend_factory
}) from whisperlivekit.simul_whisper import SimulStreamingASR
return _dict
logger = logging.getLogger(__name__)
class TranscriptionEngine: class TranscriptionEngine:
_instance = None _instance = None
_initialized = False _initialized = False
_lock = threading.Lock() # Thread-safe singleton lock
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
# Double-checked locking pattern for thread-safe singleton
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) with cls._lock:
# Check again inside lock to prevent race condition
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self, **kwargs): def __init__(self, config=None, **kwargs):
if TranscriptionEngine._initialized: # Thread-safe initialization check
return with TranscriptionEngine._lock:
if TranscriptionEngine._initialized:
return
global_params = { try:
"host": "localhost", self._do_init(config, **kwargs)
"port": 8000, except Exception:
"diarization": False, # Reset singleton so a retry is possible
"punctuation_split": False, with TranscriptionEngine._lock:
"target_language": "", TranscriptionEngine._instance = None
"vac": True, TranscriptionEngine._initialized = False
"vac_onnx": False, raise
"vac_chunk_size": 0.04,
"log_level": "DEBUG",
"ssl_certfile": None,
"ssl_keyfile": None,
"forwarded_allow_ips": None,
"transcription": True,
"vad": True,
"pcm_input": False,
"disable_punctuation_split" : False,
"diarization_backend": "sortformer",
}
global_params = update_with_kwargs(global_params, kwargs)
transcription_common_params = { with TranscriptionEngine._lock:
"backend": "simulstreaming", TranscriptionEngine._initialized = True
"warmup_file": None,
"min_chunk_size": 0.5,
"model_size": "tiny",
"model_cache_dir": None,
"model_dir": None,
"lan": "auto",
"task": "transcribe",
}
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
if transcription_common_params['model_size'].endswith(".en"): def _do_init(self, config=None, **kwargs):
transcription_common_params["lan"] = "en" # Handle negated kwargs from programmatic API
if 'no_transcription' in kwargs: if 'no_transcription' in kwargs:
global_params['transcription'] = not global_params['no_transcription'] kwargs['transcription'] = not kwargs.pop('no_transcription')
if 'no_vad' in kwargs: if 'no_vad' in kwargs:
global_params['vad'] = not kwargs['no_vad'] kwargs['vad'] = not kwargs.pop('no_vad')
if 'no_vac' in kwargs: if 'no_vac' in kwargs:
global_params['vac'] = not kwargs['no_vac'] kwargs['vac'] = not kwargs.pop('no_vac')
if config is None:
if isinstance(kwargs.get('config'), WhisperLiveKitConfig):
config = kwargs.pop('config')
else:
config = WhisperLiveKitConfig.from_kwargs(**kwargs)
self.config = config
# Backward compat: expose as self.args (Namespace-like) for AudioProcessor etc.
self.args = Namespace(**asdict(config))
self.args = Namespace(**{**global_params, **transcription_common_params})
self.asr = None self.asr = None
self.tokenizer = None self.tokenizer = None
self.diarization = None self.diarization = None
self.vac_model = None self.vac_session = None
if self.args.vac: if config.vac:
from whisperlivekit.silero_vad_iterator import load_silero_vad from whisperlivekit.silero_vad_iterator import is_onnx_available
# Use ONNX if specified, otherwise use JIT (default)
use_onnx = kwargs.get('vac_onnx', False) if is_onnx_available():
self.vac_model = load_silero_vad(onnx=use_onnx) from whisperlivekit.silero_vad_iterator import load_onnx_session
self.vac_session = load_onnx_session()
if self.args.transcription:
if self.args.backend == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingASR
simulstreaming_params = {
"disable_fast_encoder": False,
"custom_alignment_heads": None,
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
"audio_max_len": 20.0,
"audio_min_len": 0.0,
"cif_ckpt_path": None,
"never_fire": False,
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
"preload_model_count": 1,
}
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
self.tokenizer = None
self.asr = SimulStreamingASR(
**transcription_common_params, **simulstreaming_params
)
else: else:
logger.warning(
whisperstreaming_params = { "onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
"buffer_trimming": "segment", "For multi-user scenarios, install onnxruntime: pip install onnxruntime"
"confidence_validation": False,
"buffer_trimming_sec": 15,
}
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
self.asr = backend_factory(
**transcription_common_params, **whisperstreaming_params
) )
if self.args.diarization: transcription_common_params = {
if self.args.diarization_backend == "diart": "warmup_file": config.warmup_file,
from whisperlivekit.diarization.diart_backend import DiartDiarization "min_chunk_size": config.min_chunk_size,
diart_params = { "model_size": config.model_size,
"segmentation_model": "pyannote/segmentation-3.0", "model_cache_dir": config.model_cache_dir,
"embedding_model": "pyannote/embedding", "model_dir": config.model_dir,
"model_path": config.model_path,
"lora_path": config.lora_path,
"lan": config.lan,
"direct_english_translation": config.direct_english_translation,
}
if config.transcription:
if config.backend == "voxtral-mlx":
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
self.tokenizer = None
self.asr = VoxtralMLXASR(**transcription_common_params)
logger.info("Using Voxtral MLX native backend")
elif config.backend == "voxtral":
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingASR
self.tokenizer = None
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
logger.info("Using Voxtral HF Transformers streaming backend")
elif config.backend_policy == "simulstreaming":
simulstreaming_params = {
"disable_fast_encoder": config.disable_fast_encoder,
"custom_alignment_heads": config.custom_alignment_heads,
"frame_threshold": config.frame_threshold,
"beams": config.beams,
"decoder_type": config.decoder_type,
"audio_max_len": config.audio_max_len,
"audio_min_len": config.audio_min_len,
"cif_ckpt_path": config.cif_ckpt_path,
"never_fire": config.never_fire,
"init_prompt": config.init_prompt,
"static_init_prompt": config.static_init_prompt,
"max_context_tokens": config.max_context_tokens,
} }
diart_params = update_with_kwargs(diart_params, kwargs)
self.diarization_model = DiartDiarization( self.tokenizer = None
block_duration=self.args.min_chunk_size, self.asr = SimulStreamingASR(
**diart_params **transcription_common_params,
**simulstreaming_params,
backend=config.backend,
) )
elif self.args.diarization_backend == "sortformer": logger.info(
"Using SimulStreaming policy with %s backend",
getattr(self.asr, "encoder_backend", "whisper"),
)
else:
whisperstreaming_params = {
"buffer_trimming": config.buffer_trimming,
"confidence_validation": config.confidence_validation,
"buffer_trimming_sec": config.buffer_trimming_sec,
}
self.asr = backend_factory(
backend=config.backend,
**transcription_common_params,
**whisperstreaming_params,
)
logger.info(
"Using LocalAgreement policy with %s backend",
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
)
if config.diarization:
if config.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization
self.diarization_model = DiartDiarization(
block_duration=config.min_chunk_size,
segmentation_model=config.segmentation_model,
embedding_model=config.embedding_model,
)
elif config.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization_model = SortformerDiarization() self.diarization_model = SortformerDiarization()
self.translation_model = None self.translation_model = None
if self.args.target_language: if config.target_language:
if self.args.lan == 'auto' and self.args.backend != "simulstreaming": if config.lan == 'auto' and config.backend_policy != "simulstreaming":
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming') raise ValueError('Translation cannot be set with language auto when transcription backend is not simulstreaming')
else: else:
try: try:
from nllw import load_model from nllw import load_model
except: except ImportError:
raise Exception('To use translation, you must install nllw: `pip install nllw`') raise ImportError('To use translation, you must install nllw: `pip install nllw`')
translation_params = { self.translation_model = load_model(
"nllb_backend": "transformers", [config.lan],
"nllb_size": "600M" nllb_backend=config.nllb_backend,
} nllb_size=config.nllb_size,
translation_params = update_with_kwargs(translation_params, kwargs) )
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True
def online_factory(args, asr): def online_factory(args, asr):
if args.backend == "simulstreaming": if getattr(args, 'backend', None) == "voxtral-mlx":
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
return VoxtralMLXOnlineProcessor(asr)
if getattr(args, 'backend', None) == "voxtral":
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
return VoxtralHFStreamingOnlineProcessor(asr)
if args.backend_policy == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(asr) return SimulStreamingOnlineProcessor(asr)
else: return OnlineASRProcessor(asr)
online = OnlineASRProcessor(asr)
return online
def online_diarization_factory(args, diarization_backend): def online_diarization_factory(args, diarization_backend):
if args.diarization_backend == "diart": if args.diarization_backend == "diart":
online = diarization_backend online = diarization_backend
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
elif args.diarization_backend == "sortformer":
if args.diarization_backend == "sortformer": from whisperlivekit.diarization.sortformer_backend import \
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend) online = SortformerDiarizationOnline(shared_model=diarization_backend)
else:
raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")
return online return online

View File

@@ -1,32 +1,28 @@
import asyncio import asyncio
import re
import threading
import numpy as np
import logging import logging
import threading
import time import time
from queue import SimpleQueue, Empty from queue import Empty, SimpleQueue
from typing import Any, List, Tuple
import diart.models as m
import numpy as np
from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference from diart.inference import StreamingInference
from diart.sources import AudioSource from diart.sources import AudioSource, MicrophoneAudioSource
from whisperlivekit.timed_objects import SpeakerSegment
from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
from pyannote.core import Annotation from pyannote.core import Annotation
import diart.models as m from rx.core import Observer
from whisperlivekit.diarization.utils import extract_number
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def extract_number(s: str) -> int:
m = re.search(r'\d+', s)
return int(m.group()) if m else None
class DiarizationObserver(Observer): class DiarizationObserver(Observer):
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments.""" """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
def __init__(self): def __init__(self):
self.speaker_segments = [] self.diarization_segments = []
self.processed_time = 0 self.processed_time = 0
self.segment_lock = threading.Lock() self.segment_lock = threading.Lock()
self.global_time_offset = 0.0 self.global_time_offset = 0.0
@@ -48,7 +44,7 @@ class DiarizationObserver(Observer):
for speaker, label in annotation._labels.items(): for speaker, label in annotation._labels.items():
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]): for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
print(f" {speaker}: {start:.2f}s-{end:.2f}s") print(f" {speaker}: {start:.2f}s-{end:.2f}s")
self.speaker_segments.append(SpeakerSegment( self.diarization_segments.append(SpeakerSegment(
speaker=speaker, speaker=speaker,
start=start + self.global_time_offset, start=start + self.global_time_offset,
end=end + self.global_time_offset end=end + self.global_time_offset
@@ -59,14 +55,14 @@ class DiarizationObserver(Observer):
def get_segments(self) -> List[SpeakerSegment]: def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments.""" """Get a copy of the current speaker segments."""
with self.segment_lock: with self.segment_lock:
return self.speaker_segments.copy() return self.diarization_segments.copy()
def clear_old_segments(self, older_than: float = 30.0): def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time.""" """Clear segments older than the specified time."""
with self.segment_lock: with self.segment_lock:
current_time = self.processed_time current_time = self.processed_time
self.speaker_segments = [ self.diarization_segments = [
segment for segment in self.speaker_segments segment for segment in self.diarization_segments
if current_time - segment.end < older_than if current_time - segment.end < older_than
] ]
@@ -178,7 +174,6 @@ class DiartDiarization:
self.pipeline = SpeakerDiarization(config=config) self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver() self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone: if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration) self.source = MicrophoneAudioSource(block_duration=block_duration)
@@ -203,46 +198,20 @@ class DiartDiarization:
def insert_silence(self, silence_duration): def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray): def insert_audio_chunk(self, pcm_array: np.ndarray):
""" """Buffer audio for the next diarization step."""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
if self.custom_source: if self.custom_source:
self.custom_source.push_audio(pcm_array) self.custom_source.push_audio(pcm_array)
# self.observer.clear_old_segments()
async def diarize(self):
"""Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments()
def close(self): def close(self):
"""Close the audio source.""" """Close the audio source."""
if self.custom_source: if self.custom_source:
self.custom_source.close() self.custom_source.close()
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
"""
segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
if not use_punctuation_split:
for token in tokens:
for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
else:
tokens = add_speaker_to_tokens(segments, tokens)
return tokens
def concatenate_speakers(segments): def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}] segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]

View File

@@ -1,11 +1,12 @@
import numpy as np
import torch
import logging import logging
import threading import threading
import time import time
import wave import wave
from queue import Empty, SimpleQueue
from typing import List, Optional from typing import List, Optional
from queue import SimpleQueue, Empty
import numpy as np
import torch
from whisperlivekit.timed_objects import SpeakerSegment from whisperlivekit.timed_objects import SpeakerSegment
@@ -94,11 +95,11 @@ class SortformerDiarizationOnline:
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2") model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
""" """
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.speaker_segments = [] self.diarization_segments = []
self.diar_segments = []
self.buffer_audio = np.array([], dtype=np.float32) self.buffer_audio = np.array([], dtype=np.float32)
self.segment_lock = threading.Lock() self.segment_lock = threading.Lock()
self.global_time_offset = 0.0 self.global_time_offset = 0.0
self.processed_time = 0.0
self.debug = False self.debug = False
self.diar_model = shared_model.diar_model self.diar_model = shared_model.diar_model
@@ -155,12 +156,10 @@ class SortformerDiarizationOnline:
) )
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device) self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
# Initialize total predictions tensor
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device) self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: float): def insert_silence(self, silence_duration: Optional[float]):
""" """
Insert silence period by adjusting the global time offset. Insert silence period by adjusting the global time offset.
@@ -171,248 +170,111 @@ class SortformerDiarizationOnline:
self.global_time_offset += silence_duration self.global_time_offset += silence_duration
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s") logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
async def diarize(self, pcm_array: np.ndarray): def insert_audio_chunk(self, pcm_array: np.ndarray):
if self.debug:
self.audio_buffer.append(pcm_array.copy())
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
async def diarize(self):
""" """
Process audio data for diarization in streaming fashion. Process audio data for diarization in streaming fashion.
Args: Args:
pcm_array: Audio data as numpy array pcm_array: Audio data as numpy array
""" """
try:
if self.debug:
self.audio_buffer.append(pcm_array.copy())
threshold = int(self.chunk_duration_seconds * self.sample_rate) threshold = int(self.chunk_duration_seconds * self.sample_rate)
if not len(self.buffer_audio) >= threshold:
return []
audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:]
device = self.diar_model.device
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else:
total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()]) self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
if not len(self.buffer_audio) >= threshold: processed_signal=chunk_feat_seq_t,
return processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
streaming_state=self.streaming_state,
audio = self.buffer_audio[:threshold] total_preds=self.total_preds,
self.buffer_audio = self.buffer_audio[threshold:] left_offset=left_offset,
right_offset=right_offset,
device = self.diar_model.device )
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0) new_segments = self._process_predictions()
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
self._chunk_index += 1
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features( return new_segments
audio_signal_chunk, audio_signal_length_chunk
)
processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else:
total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
# Convert predictions to speaker segments
self._process_predictions()
self._chunk_index += 1
except Exception as e:
logger.error(f"Error in diarize: {e}")
raise
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
def _process_predictions(self): def _process_predictions(self):
"""Process model predictions and convert to speaker segments.""" """Process model predictions and convert to speaker segments."""
try: preds_np = self.total_preds[0].cpu().numpy()
preds_np = self.total_preds[0].cpu().numpy() active_speakers = np.argmax(preds_np, axis=1)
active_speakers = np.argmax(preds_np, axis=1)
if self._len_prediction is None:
self._len_prediction = len(active_speakers)
# Get predictions for current chunk
frame_duration = self.chunk_duration_seconds / self._len_prediction
current_chunk_preds = active_speakers[-self._len_prediction:]
with self.segment_lock:
# Process predictions into segments
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
for idx, spk in enumerate(current_chunk_preds):
start_time = base_time + idx * frame_duration
end_time = base_time + (idx + 1) * frame_duration
# Check if this continues the last segment or starts a new one
if (self.speaker_segments and
self.speaker_segments[-1].speaker == spk and
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
# Continue existing segment
self.speaker_segments[-1].end = end_time
else:
# Create new segment
self.speaker_segments.append(SpeakerSegment(
speaker=spk,
start=start_time,
end=end_time
))
# Update processed time
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
except Exception as e:
logger.error(f"Error processing predictions: {e}")
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Args: if self._len_prediction is None:
tokens: List of tokens with timing information self._len_prediction = len(active_speakers) #12
use_punctuation_split: Whether to use punctuation for boundary refinement
frame_duration = self.chunk_duration_seconds / self._len_prediction
Returns: current_chunk_preds = active_speakers[-self._len_prediction:]
List of tokens with speaker assignments
Last speaker_segment new_segments = []
"""
with self.segment_lock: with self.segment_lock:
segments = self.speaker_segments.copy() base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
current_spk = current_chunk_preds[0]
if not segments or not tokens: start_time = round(base_time, 2)
logger.debug("No segments or tokens available for speaker assignment") for idx, spk in enumerate(current_chunk_preds):
return tokens current_time = round(base_time + idx * frame_duration, 2)
if spk != current_spk:
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments") new_segments.append(SpeakerSegment(
use_punctuation_split = False speaker=current_spk,
if not use_punctuation_split: start=start_time,
# Simple overlap-based assignment end=current_time
for token in tokens: ))
token.speaker = -1 # Default to no speaker start_time = current_time
for segment in segments: current_spk = spk
# Check for timing overlap new_segments.append(
if not (segment.end <= token.start or segment.start >= token.end): SpeakerSegment(
token.speaker = segment.speaker + 1 # Convert to 1-based indexing speaker=current_spk,
break start=start_time,
else: end=current_time
# Use punctuation-aware assignment (similar to diart_backend) )
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens) )
return new_segments
return tokens
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
"""
Assign speakers to tokens with punctuation-aware boundary adjustment.
Args:
segments: List of speaker segments
tokens: List of tokens to assign speakers to
Returns:
List of tokens with speaker assignments
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
# Convert segments to concatenated format
segments_concatenated = self._concatenate_speakers(segments)
# Adjust segment boundaries based on punctuation
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
"""
Concatenate consecutive segments from the same speaker.
Args:
segments: List of speaker segments
Returns:
List of concatenated speaker segments
"""
if not segments:
return []
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = segment.speaker + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def get_segments(self) -> List[SpeakerSegment]: def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments.""" """Get a copy of the current speaker segments."""
with self.segment_lock: with self.segment_lock:
return self.speaker_segments.copy() return self.diarization_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.speaker_segments = [
segment for segment in self.speaker_segments
if current_time - segment.end < older_than
]
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
def close(self): def close(self):
"""Close the diarization system and clean up resources.""" """Close the diarization system and clean up resources."""
logger.info("Closing SortformerDiarization") logger.info("Closing SortformerDiarization")
with self.segment_lock: with self.segment_lock:
self.speaker_segments.clear() self.diarization_segments.clear()
if self.debug: if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer) concatenated_audio = np.concatenate(self.audio_buffer)
@@ -425,20 +287,17 @@ class SortformerDiarizationOnline:
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav") logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
def extract_number(s: str) -> int: from whisperlivekit.diarization.utils import extract_number
"""Extract number from speaker string (compatibility function)."""
import re
m = re.search(r'\d+', s)
return int(m.group()) if m else 0
if __name__ == '__main__': if __name__ == '__main__':
import asyncio import asyncio
import librosa import librosa
async def main(): async def main():
"""TEST ONLY.""" """TEST ONLY."""
an4_audio = 'audio_test.mp3' an4_audio = 'diarization_audio.wav'
signal, sr = librosa.load(an4_audio, sr=16000) signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30] signal = signal[:16000*30]
@@ -450,13 +309,15 @@ if __name__ == '__main__':
print("Speaker 0: 0:25 - 0:30") print("Speaker 0: 0:25 - 0:30")
print("=" * 50) print("=" * 50)
diarization = SortformerDiarization(sample_rate=16000) diarization_backend = SortformerDiarization()
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
chunk_size = 1600 chunk_size = 1600
for i in range(0, len(signal), chunk_size): for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size] chunk = signal[i:i+chunk_size]
await diarization.diarize(chunk) new_segments = await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}") print(f"Processed chunk {i // chunk_size + 1}")
print(new_segments)
segments = diarization.get_segments() segments = diarization.get_segments()
print("\nDiarization results:") print("\nDiarization results:")

View File

@@ -1,205 +0,0 @@
import numpy as np
import torch
import logging
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
import librosa
logger = logging.getLogger(__name__)
def load_model():
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
#we target 1 second lag for the moment. chunk_len could be reduced.
diar_model.sortformer_modules.chunk_len = 10
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
diar_model.sortformer_modules.chunk_right_context = 0 #no.
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
audio2mel = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
return diar_model, audio2mel
diar_model, audio2mel = load_model()
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(chunks):
"""
what it does:
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
2. STFT: Computes the Short-Time Fourier Transform using:
- the window of window_size=0.025 --> size of a window : 400 samples
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
6. Normalization: Skips normalization since `normalize="NA"`
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
"""
previous_chunk = None
l_chunk_feat_seq_t = []
for chunk in chunks:
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
if previous_chunk is not None:
to_add = previous_chunk[:, :, -99:]
total = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total = processed_signal_chunk
previous_chunk = processed_signal_chunk
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
batch_size = 1
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
for speaker in l_speakers:
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
if __name__ == '__main__':
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
# signal = signal[:-(len(signal)%16000)]
print("\n" + "=" * 50)
print("Expected ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(chunks)

View File

@@ -0,0 +1,7 @@
import re
def extract_number(s: str) -> int:
"""Extract the first integer from a string, e.g. 'speaker_2' -> 2."""
m = re.search(r'\d+', s)
return int(m.group()) if m else 0

View File

@@ -1,8 +1,8 @@
import asyncio import asyncio
import contextlib
import logging import logging
from enum import Enum from enum import Enum
from typing import Optional, Callable from typing import Callable, Optional
import contextlib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)

View File

@@ -1,32 +1,31 @@
import sys
import logging
import io import io
import soundfile as sf import logging
import math import math
import sys
from typing import List from typing import List
import numpy as np import numpy as np
import soundfile as sf
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ASRBase: class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped, sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed) # "" for faster-whisper because it emits the spaces when needed)
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr): def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
self.logfile = logfile self.logfile = logfile
self.transcribe_kargs = {} self.transcribe_kargs = {}
self.lora_path = lora_path
if lan == "auto": if lan == "auto":
self.original_language = None self.original_language = None
else: else:
self.original_language = lan self.original_language = lan
self.model = self.load_model(model_size, cache_dir, model_dir) self.model = self.load_model(model_size, cache_dir, model_dir)
def with_offset(self, offset: float) -> ASRToken:
# This method is kept for compatibility (typically you will use ASRToken.with_offset)
return ASRToken(self.start + offset, self.end + offset, self.text)
def __repr__(self):
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
def load_model(self, model_size, cache_dir, model_dir): def load_model(self, model_size, cache_dir, model_dir):
raise NotImplementedError("must be implemented in the child class") raise NotImplementedError("must be implemented in the child class")
@@ -37,40 +36,59 @@ class ASRBase:
raise NotImplementedError("must be implemented in the child class") raise NotImplementedError("must be implemented in the child class")
class WhisperTimestampedASR(ASRBase): class WhisperASR(ASRBase):
"""Uses whisper_timestamped as the backend.""" """Uses WhisperLiveKit's built-in Whisper implementation."""
sep = " " sep = " "
def load_model(self, model_size=None, cache_dir=None, model_dir=None): def load_model(self, model_size=None, cache_dir=None, model_dir=None):
import whisper from whisperlivekit.whisper import load_model as load_whisper_model
import whisper_timestamped
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None: if model_dir is not None:
logger.debug("ignoring model_dir, not implemented") resolved_path = resolve_model_path(model_dir)
return whisper.load_model(model_size, download_root=cache_dir) if resolved_path.is_dir():
model_info = detect_model_format(resolved_path)
if not model_info.has_pytorch:
raise FileNotFoundError(
f"No supported PyTorch checkpoint found under {resolved_path}"
)
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
if model_size is None:
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)
def transcribe(self, audio, init_prompt=""): def transcribe(self, audio, init_prompt=""):
result = self.transcribe_timestamped( options = dict(self.transcribe_kargs)
options.pop("vad", None)
options.pop("vad_filter", None)
language = self.original_language if self.original_language else None
result = whisper_transcribe(
self.model, self.model,
audio, audio,
language=self.original_language, language=language,
initial_prompt=init_prompt, initial_prompt=init_prompt,
verbose=None,
condition_on_previous_text=True, condition_on_previous_text=True,
**self.transcribe_kargs, word_timestamps=True,
**options,
) )
return result return result
def ts_words(self, r) -> List[ASRToken]: def ts_words(self, r) -> List[ASRToken]:
""" """
Converts the whisper_timestamped result to a list of ASRToken objects. Converts the Whisper result to a list of ASRToken objects.
""" """
tokens = [] tokens = []
for segment in r["segments"]: for segment in r["segments"]:
for word in segment["words"]: for word in segment["words"]:
token = ASRToken(word["start"], word["end"], word["text"]) token = ASRToken(
word["start"],
word["end"],
word["word"],
probability=word.get("probability"),
)
tokens.append(token) tokens.append(token)
return tokens return tokens
@@ -78,11 +96,7 @@ class WhisperTimestampedASR(ASRBase):
return [segment["end"] for segment in res["segments"]] return [segment["end"] for segment in res["segments"]]
def use_vad(self): def use_vad(self):
self.transcribe_kargs["vad"] = True logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class FasterWhisperASR(ASRBase): class FasterWhisperASR(ASRBase):
"""Uses faster-whisper as the backend.""" """Uses faster-whisper as the backend."""
@@ -92,9 +106,10 @@ class FasterWhisperASR(ASRBase):
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
if model_dir is not None: if model_dir is not None:
logger.debug(f"Loading whisper model from model_dir {model_dir}. " resolved_path = resolve_model_path(model_dir)
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
f"model_size and cache_dir parameters are not used.") f"model_size and cache_dir parameters are not used.")
model_size_or_path = model_dir model_size_or_path = str(resolved_path)
elif model_size is not None: elif model_size is not None:
model_size_or_path = model_size model_size_or_path = model_size
else: else:
@@ -139,10 +154,6 @@ class FasterWhisperASR(ASRBase):
def use_vad(self): def use_vad(self):
self.transcribe_kargs["vad_filter"] = True self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class MLXWhisper(ASRBase): class MLXWhisper(ASRBase):
""" """
Uses MLX Whisper optimized for Apple Silicon. Uses MLX Whisper optimized for Apple Silicon.
@@ -150,12 +161,13 @@ class MLXWhisper(ASRBase):
sep = "" sep = ""
def load_model(self, model_size=None, cache_dir=None, model_dir=None): def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx import mlx.core as mx
from mlx_whisper.transcribe import ModelHolder, transcribe
if model_dir is not None: if model_dir is not None:
logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.") resolved_path = resolve_model_path(model_dir)
model_size_or_path = model_dir logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
model_size_or_path = str(resolved_path)
elif model_size is not None: elif model_size is not None:
model_size_or_path = self.translate_model_name(model_size) model_size_or_path = self.translate_model_name(model_size)
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.") logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
@@ -168,22 +180,8 @@ class MLXWhisper(ASRBase):
return transcribe return transcribe
def translate_model_name(self, model_name): def translate_model_name(self, model_name):
model_mapping = { from whisperlivekit.model_mapping import MLX_MODEL_MAPPING
"tiny.en": "mlx-community/whisper-tiny.en-mlx", mlx_model_path = MLX_MODEL_MAPPING.get(model_name)
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
mlx_model_path = model_mapping.get(model_name)
if mlx_model_path: if mlx_model_path:
return mlx_model_path return mlx_model_path
else: else:
@@ -208,7 +206,7 @@ class MLXWhisper(ASRBase):
if segment.get("no_speech_prob", 0) > 0.9: if segment.get("no_speech_prob", 0) > 0.9:
continue continue
for word in segment.get("words", []): for word in segment.get("words", []):
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"]) token = ASRToken(word["start"], word["end"], word["word"])
tokens.append(token) tokens.append(token)
return tokens return tokens
@@ -218,9 +216,6 @@ class MLXWhisper(ASRBase):
def use_vad(self): def use_vad(self):
self.transcribe_kargs["vad_filter"] = True self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class OpenaiApiASR(ASRBase): class OpenaiApiASR(ASRBase):
"""Uses OpenAI's Whisper API for transcription.""" """Uses OpenAI's Whisper API for transcription."""
@@ -232,6 +227,7 @@ class OpenaiApiASR(ASRBase):
self.temperature = temperature self.temperature = temperature
self.load_model() self.load_model()
self.use_vad_opt = False self.use_vad_opt = False
self.direct_english_translation = False
self.task = "transcribe" self.task = "transcribe"
def load_model(self, *args, **kwargs): def load_model(self, *args, **kwargs):
@@ -274,17 +270,15 @@ class OpenaiApiASR(ASRBase):
"temperature": self.temperature, "temperature": self.temperature,
"timestamp_granularities": ["word", "segment"], "timestamp_granularities": ["word", "segment"],
} }
if self.task != "translate" and self.original_language: if not self.direct_english_translation and self.original_language:
params["language"] = self.original_language params["language"] = self.original_language
if prompt: if prompt:
params["prompt"] = prompt params["prompt"] = prompt
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions task = self.transcribe_kargs.get("task", self.task)
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
transcript = proc.create(**params) transcript = proc.create(**params)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds") logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
return transcript return transcript
def use_vad(self): def use_vad(self):
self.use_vad_opt = True self.use_vad_opt = True
def set_translate_task(self):
self.task = "translate"

View File

@@ -1,7 +1,9 @@
import sys
import numpy as np
import logging import logging
from typing import List, Tuple, Optional import sys
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -134,6 +136,11 @@ class OnlineASRProcessor:
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM." f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
) )
def new_speaker(self, change_speaker):
"""Handle speaker change event."""
self.process_iter()
self.init(offset=change_speaker.start)
def init(self, offset: Optional[float] = None): def init(self, offset: Optional[float] = None):
"""Initialize or reset the processing buffers.""" """Initialize or reset the processing buffers."""
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
@@ -151,21 +158,32 @@ class OnlineASRProcessor:
"""Append an audio chunk (a numpy array) to the current audio buffer.""" """Append an audio chunk (a numpy array) to the current audio buffer."""
self.audio_buffer = np.append(self.audio_buffer, audio) self.audio_buffer = np.append(self.audio_buffer, audio)
def insert_silence(self, silence_duration, offset): def start_silence(self):
""" if self.audio_buffer.size == 0:
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame return [], self.get_audio_buffer_end_time()
""" return self.process_iter()
# if self.transcript_buffer.buffer:
# self.committed.extend(self.transcript_buffer.buffer) def end_silence(self, silence_duration: Optional[float], offset: float):
# self.transcript_buffer.buffer = [] if not silence_duration or silence_duration <= 0:
return
if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence.
gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16) long_silence = silence_duration >= 5
self.insert_audio_chunk(gap_silence) if not long_silence:
gap_samples = int(self.SAMPLING_RATE * silence_duration)
if gap_samples > 0:
gap_silence = np.zeros(gap_samples, dtype=np.float32)
self.insert_audio_chunk(gap_silence)
else: else:
self.init(offset=silence_duration + offset) self.init(offset=silence_duration + offset)
self.global_time_offset += silence_duration self.global_time_offset += silence_duration
def insert_silence(self, silence_duration, offset):
"""
Backwards compatibility shim for legacy callers that still use insert_silence.
"""
self.end_silence(silence_duration, offset)
def prompt(self) -> Tuple[str, str]: def prompt(self) -> Tuple[str, str]:
""" """
Returns a tuple: (prompt, context), where: Returns a tuple: (prompt, context), where:
@@ -400,11 +418,11 @@ class OnlineASRProcessor:
) -> Transcript: ) -> Transcript:
sep = sep if sep is not None else self.asr.sep sep = sep if sep is not None else self.asr.sep
text = sep.join(token.text for token in tokens) text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None # probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens: if tokens:
start = offset + tokens[0].start start = offset + tokens[0].start
end = offset + tokens[-1].end end = offset + tokens[-1].end
else: else:
start = None start = None
end = None end = None
return Transcript(start, end, text, probability=probability) return Transcript(start, end, text)

View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python3
import logging
import platform
import time
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.warmup import warmup_asr
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
logger = logging.getLogger(__name__)
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
","
)
def create_tokenizer(lan):
"""returns an object that has split function that works like the one of MosesTokenizer"""
assert (
lan in WHISPER_LANG_CODES
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
if lan == "uk":
import tokenize_uk
class UkrainianTokenizer:
def split(self, text):
return tokenize_uk.tokenize_sents(text)
return UkrainianTokenizer()
# supported by fast-mosestokenizer
if (
lan
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
):
from mosestokenizer import MosesSentenceSplitter
return MosesSentenceSplitter(lan)
# the following languages are in Whisper, but not in wtpsplit:
if (
lan
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
):
logger.debug(
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
)
lan = None
from wtpsplit import WtP
# downloads the model from huggingface on the first use
wtp = WtP("wtp-canine-s-12l-no-adapters")
class WtPtok:
def split(self, sent):
return wtp.split(sent, lang_code=lan)
return WtPtok()
def backend_factory(
backend,
lan,
model_size,
model_cache_dir,
model_dir,
model_path,
lora_path,
direct_english_translation,
buffer_trimming,
buffer_trimming_sec,
confidence_validation,
warmup_file=None,
min_chunk_size=None,
):
backend_choice = backend
custom_reference = model_path or model_dir
resolved_root = None
has_mlx_weights = False
has_fw_weights = False
has_pytorch = False
if custom_reference:
resolved_root = resolve_model_path(custom_reference)
if resolved_root.is_dir():
model_info = detect_model_format(resolved_root)
has_mlx_weights = model_info.compatible_whisper_mlx
has_fw_weights = model_info.compatible_faster_whisper
has_pytorch = model_info.has_pytorch
else:
# Single file provided
has_pytorch = True
if backend_choice == "openai-api":
logger.debug("Using OpenAI API.")
asr = OpenaiApiASR(lan=lan)
else:
backend_choice = _normalize_backend_choice(
backend_choice,
resolved_root,
has_mlx_weights,
has_fw_weights,
)
if backend_choice == "faster-whisper":
asr_cls = FasterWhisperASR
if resolved_root is not None and not resolved_root.is_dir():
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
model_override = str(resolved_root) if resolved_root is not None else None
elif backend_choice == "mlx-whisper":
asr_cls = MLXWhisper
if resolved_root is not None and not resolved_root.is_dir():
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
model_override = str(resolved_root) if resolved_root is not None else None
else:
asr_cls = WhisperASR
model_override = str(resolved_root) if resolved_root is not None else None
if custom_reference and not has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
)
t = time.time()
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
asr = asr_cls(
model_size=model_size,
lan=lan,
cache_dir=model_cache_dir,
model_dir=model_override,
lora_path=lora_path if backend_choice == "whisper" else None,
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
if direct_english_translation:
tgt_language = "en" # Whisper translates into English
asr.transcribe_kargs["task"] = "translate"
else:
tgt_language = lan # Whisper transcribes in this language
# Create the tokenizer
if buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming
asr.buffer_trimming_sec = buffer_trimming_sec
asr.backend_choice = backend_choice
return asr
def _normalize_backend_choice(
preferred_backend,
resolved_root,
has_mlx_weights,
has_fw_weights,
):
backend_choice = preferred_backend
if backend_choice == "auto":
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
return "mlx-whisper"
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
return "faster-whisper"
return "whisper"
if backend_choice == "mlx-whisper":
if not mlx_backend_available():
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
if resolved_root is not None and not has_mlx_weights:
raise FileNotFoundError(
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
)
if platform.system() != "Darwin":
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
return backend_choice
if backend_choice == "faster-whisper":
if not faster_backend_available():
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
if resolved_root is not None and not has_fw_weights:
raise FileNotFoundError(
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
)
return backend_choice
if backend_choice == "whisper":
return backend_choice
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")

156
whisperlivekit/metrics.py Normal file
View File

@@ -0,0 +1,156 @@
"""Lightweight ASR evaluation metrics — no external dependencies.
Provides WER (Word Error Rate) computation via word-level Levenshtein distance,
text normalization, and word-level timestamp accuracy metrics with greedy alignment.
"""
import re
import unicodedata
from typing import Dict, List, Optional
def normalize_text(text: str) -> str:
"""Normalize text for WER comparison: lowercase, strip punctuation, collapse whitespace."""
text = text.lower()
# Normalize unicode (e.g., accented chars to composed form)
text = unicodedata.normalize("NFC", text)
# Remove punctuation (keep letters, numbers, spaces, hyphens within words)
text = re.sub(r"[^\w\s\-']", " ", text)
# Collapse whitespace
text = re.sub(r"\s+", " ", text).strip()
return text
def compute_wer(reference: str, hypothesis: str) -> Dict:
"""Compute Word Error Rate using word-level Levenshtein edit distance.
Args:
reference: Ground truth transcription.
hypothesis: Predicted transcription.
Returns:
Dict with keys: wer, substitutions, insertions, deletions, ref_words, hyp_words.
WER can exceed 1.0 if there are more errors than reference words.
"""
ref_words = normalize_text(reference).split()
hyp_words = normalize_text(hypothesis).split()
n = len(ref_words)
m = len(hyp_words)
if n == 0:
return {
"wer": 0.0 if m == 0 else float(m),
"substitutions": 0,
"insertions": m,
"deletions": 0,
"ref_words": 0,
"hyp_words": m,
}
# DP table: dp[i][j] = (edit_distance, substitutions, insertions, deletions)
dp = [[(0, 0, 0, 0) for _ in range(m + 1)] for _ in range(n + 1)]
for i in range(1, n + 1):
dp[i][0] = (i, 0, 0, i)
for j in range(1, m + 1):
dp[0][j] = (j, 0, j, 0)
for i in range(1, n + 1):
for j in range(1, m + 1):
if ref_words[i - 1] == hyp_words[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
else:
sub = dp[i - 1][j - 1]
ins = dp[i][j - 1]
dele = dp[i - 1][j]
sub_cost = (sub[0] + 1, sub[1] + 1, sub[2], sub[3])
ins_cost = (ins[0] + 1, ins[1], ins[2] + 1, ins[3])
del_cost = (dele[0] + 1, dele[1], dele[2], dele[3] + 1)
dp[i][j] = min(sub_cost, del_cost, ins_cost, key=lambda x: x[0])
dist, subs, ins, dels = dp[n][m]
return {
"wer": dist / n,
"substitutions": subs,
"insertions": ins,
"deletions": dels,
"ref_words": n,
"hyp_words": m,
}
def compute_timestamp_accuracy(
predicted: List[Dict],
reference: List[Dict],
) -> Dict:
"""Compute timestamp accuracy by aligning predicted words to reference words.
Uses greedy left-to-right alignment on normalized text. For each matched pair,
computes the start-time delta (predicted - reference).
Args:
predicted: List of dicts with keys: word, start, end.
reference: List of dicts with keys: word, start, end.
Returns:
Dict with keys: mae_start, max_delta_start, median_delta_start,
n_matched, n_ref, n_pred. Returns None values if no matches found.
"""
if not predicted or not reference:
return {
"mae_start": None,
"max_delta_start": None,
"median_delta_start": None,
"n_matched": 0,
"n_ref": len(reference),
"n_pred": len(predicted),
}
# Normalize words for matching
pred_norm = [normalize_text(p["word"]) for p in predicted]
ref_norm = [normalize_text(r["word"]) for r in reference]
# Greedy left-to-right alignment
deltas_start = []
ref_idx = 0
for p_idx, p_word in enumerate(pred_norm):
if not p_word:
continue
# Scan forward in reference to find a match (allow small skips)
search_limit = min(ref_idx + 3, len(ref_norm))
for r_idx in range(ref_idx, search_limit):
if ref_norm[r_idx] == p_word:
delta = predicted[p_idx]["start"] - reference[r_idx]["start"]
deltas_start.append(delta)
ref_idx = r_idx + 1
break
if not deltas_start:
return {
"mae_start": None,
"max_delta_start": None,
"median_delta_start": None,
"n_matched": 0,
"n_ref": len(reference),
"n_pred": len(predicted),
}
abs_deltas = [abs(d) for d in deltas_start]
sorted_abs = sorted(abs_deltas)
n = len(sorted_abs)
if n % 2 == 1:
median = sorted_abs[n // 2]
else:
median = (sorted_abs[n // 2 - 1] + sorted_abs[n // 2]) / 2
return {
"mae_start": sum(abs_deltas) / len(abs_deltas),
"max_delta_start": max(abs_deltas),
"median_delta_start": median,
"n_matched": len(deltas_start),
"n_ref": len(reference),
"n_pred": len(predicted),
}

View File

@@ -0,0 +1,84 @@
"""Lightweight runtime metrics for AudioProcessor sessions.
Zero external dependencies. Negligible overhead when not queried —
just integer increments and list appends during normal operation.
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List
logger = logging.getLogger(__name__)
@dataclass
class SessionMetrics:
"""Per-session metrics collected by AudioProcessor."""
session_start: float = 0.0
total_audio_duration_s: float = 0.0
total_processing_time_s: float = 0.0
# Chunk / call counters
n_chunks_received: int = 0
n_transcription_calls: int = 0
n_tokens_produced: int = 0
n_responses_sent: int = 0
# Per-call ASR latency (seconds)
transcription_durations: List[float] = field(default_factory=list)
# Silence
n_silence_events: int = 0
total_silence_duration_s: float = 0.0
# --- Computed properties ---
@property
def rtf(self) -> float:
"""Real-time factor: processing_time / audio_duration."""
if self.total_audio_duration_s <= 0:
return 0.0
return self.total_processing_time_s / self.total_audio_duration_s
@property
def avg_latency_ms(self) -> float:
"""Average per-call ASR latency in milliseconds."""
if not self.transcription_durations:
return 0.0
return (sum(self.transcription_durations) / len(self.transcription_durations)) * 1000
@property
def p95_latency_ms(self) -> float:
"""95th percentile per-call ASR latency in milliseconds."""
if not self.transcription_durations:
return 0.0
sorted_d = sorted(self.transcription_durations)
idx = int(len(sorted_d) * 0.95)
idx = min(idx, len(sorted_d) - 1)
return sorted_d[idx] * 1000
def to_dict(self) -> Dict:
"""Serialize to a plain dict (JSON-safe)."""
return {
"session_start": self.session_start,
"total_audio_duration_s": round(self.total_audio_duration_s, 3),
"total_processing_time_s": round(self.total_processing_time_s, 3),
"rtf": round(self.rtf, 3),
"n_chunks_received": self.n_chunks_received,
"n_transcription_calls": self.n_transcription_calls,
"n_tokens_produced": self.n_tokens_produced,
"n_responses_sent": self.n_responses_sent,
"avg_latency_ms": round(self.avg_latency_ms, 2),
"p95_latency_ms": round(self.p95_latency_ms, 2),
"n_silence_events": self.n_silence_events,
"total_silence_duration_s": round(self.total_silence_duration_s, 3),
}
def log_summary(self) -> None:
"""Emit a structured log line summarising the session."""
self.total_processing_time_s = sum(self.transcription_durations)
d = self.to_dict()
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
logger.info(f"SESSION_METRICS {d}")

View File

@@ -0,0 +1,17 @@
"""Shared MLX model name mapping used by both SimulStreaming and LocalAgreement backends."""
MLX_MODEL_MAPPING = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}

View File

@@ -0,0 +1,215 @@
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union
@dataclass
class ModelInfo:
"""Information about detected model format and files in a directory."""
path: Optional[Path] = None
pytorch_files: List[Path] = field(default_factory=list)
compatible_whisper_mlx: bool = False
compatible_faster_whisper: bool = False
@property
def has_pytorch(self) -> bool:
return len(self.pytorch_files) > 0
@property
def is_sharded(self) -> bool:
return len(self.pytorch_files) > 1
@property
def primary_pytorch_file(self) -> Optional[Path]:
"""Return the primary PyTorch file (or first shard for sharded models)."""
if not self.pytorch_files:
return None
return self.pytorch_files[0]
#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin
SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$")
FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"}
MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"}
CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"}
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
"""
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
CTranslate2 models have specific companion files that distinguish them
from PyTorch .bin files.
"""
n_indicators = 0
for indicator in CT2_INDICATOR_FILES: #test 1
if (directory / indicator).exists():
n_indicators += 1
if n_indicators == 0:
return False
config_path = directory / "config.json" #test 2
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
if config.get("model_type") == "whisper": #test 2
return False
except (json.JSONDecodeError, IOError):
pass
return True
def _collect_pytorch_files(directory: Path) -> List[Path]:
"""
Collect all PyTorch checkpoint files from a directory.
Handles:
- Single files: model.safetensors, pytorch_model.bin, *.pt
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
- Index-based sharded models (reads index file to find shards)
Returns files sorted appropriately (shards in order, or single file).
"""
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
index_path = directory / index_name
if index_path.exists():
try:
with open(index_path, "r", encoding="utf-8") as f:
index_data = json.load(f)
weight_map = index_data.get("weight_map", {})
if weight_map:
shard_names = sorted(set(weight_map.values()))
shards = [directory / name for name in shard_names if (directory / name).exists()]
if shards:
return shards
except (json.JSONDecodeError, IOError):
pass
sharded_groups = {}
single_files = {}
for file in directory.iterdir():
if not file.is_file():
continue
filename = file.name
suffix = file.suffix.lower()
if filename.startswith("adapter_"):
continue
match = SHARDED_PATTERN.match(filename)
if match:
base_name, shard_idx, total_shards, ext = match.groups()
key = (base_name, ext, int(total_shards))
if key not in sharded_groups:
sharded_groups[key] = []
sharded_groups[key].append((int(shard_idx), file))
continue
if filename == "model.safetensors":
single_files[0] = file # Highest priority
elif filename == "pytorch_model.bin":
single_files[1] = file
elif suffix == ".pt":
single_files[2] = file
elif suffix == ".safetensors" and not filename.startswith("adapter"):
single_files[3] = file
for (base_name, ext, total_shards), shards in sharded_groups.items():
if len(shards) == total_shards:
return [path for _, path in sorted(shards)]
for priority in sorted(single_files.keys()):
return [single_files[priority]]
return []
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
"""
Detect the model format in a given path.
This function analyzes a file or directory to determine:
- What PyTorch checkpoint files are available (including sharded models)
- Whether the directory contains MLX Whisper weights
- Whether the directory contains Faster-Whisper (CTranslate2) weights
Args:
model_path: Path to a model file or directory
Returns:
ModelInfo with detected format information
"""
path = Path(model_path)
info = ModelInfo(path=path)
if path.is_file():
suffix = path.suffix.lower()
if suffix in {".pt", ".safetensors", ".bin"}:
info.pytorch_files = [path]
return info
if not path.is_dir():
return info
for file in path.iterdir():
if not file.is_file():
continue
filename = file.name.lower()
if filename in MLX_WHISPER_MARKERS:
info.compatible_whisper_mlx = True
if filename in FASTER_WHISPER_MARKERS:
if _is_ct2_model_bin(path, filename):
info.compatible_faster_whisper = True
info.pytorch_files = _collect_pytorch_files(path)
return info
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
"""
Inspect the provided path and determine which model formats are available.
This is a compatibility wrapper around detect_model_format().
Returns:
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
compatible_whisper_mlx: True if MLX weights exist in this folder.
compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
"""
info = detect_model_format(model_path)
return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper
def resolve_model_path(model_path: Union[str, Path]) -> Path:
"""
Return a local path for the provided model reference.
If the path does not exist locally, it is treated as a Hugging Face repo id
and downloaded via snapshot_download.
"""
path = Path(model_path).expanduser()
if path.exists():
return path
try:
from huggingface_hub import snapshot_download
except ImportError as exc:
raise FileNotFoundError(
f"Model path '{model_path}' does not exist locally and huggingface_hub "
"is not installed to download it."
) from exc
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
return downloaded_path

View File

@@ -1,6 +1,7 @@
from argparse import ArgumentParser from argparse import ArgumentParser
def parse_args(): def parse_args():
parser = ArgumentParser(description="Whisper FastAPI Online Server") parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument( parser.add_argument(
@@ -81,14 +82,14 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--min-chunk-size", "--min-chunk-size",
type=float, type=float,
default=0.5, default=0.1,
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.", help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
) )
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
default="small", default="base",
dest='model_size', dest='model_size',
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.", help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
) )
@@ -105,6 +106,13 @@ def parse_args():
default=None, default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.", help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
) )
parser.add_argument(
"--lora-path",
type=str,
default=None,
dest="lora_path",
help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). Only works with native Whisper backend.",
)
parser.add_argument( parser.add_argument(
"--lan", "--lan",
"--language", "--language",
@@ -114,11 +122,10 @@ def parse_args():
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.", help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
) )
parser.add_argument( parser.add_argument(
"--task", "--direct-english-translation",
type=str, action="store_true",
default="transcribe", default=False,
choices=["transcribe", "translate"], help="Use Whisper to directly translate to english.",
help="Transcribe or translate.",
) )
parser.add_argument( parser.add_argument(
@@ -130,11 +137,18 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--backend", "--backend-policy",
type=str, type=str,
default="simulstreaming", default="simulstreaming",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"], choices=["1", "2", "simulstreaming", "localagreement"],
help="Load only this backend for Whisper processing.", help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
)
parser.add_argument(
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx"],
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon.",
) )
parser.add_argument( parser.add_argument(
"--no-vac", "--no-vac",
@@ -289,14 +303,6 @@ def parse_args():
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.", help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
) )
simulstreaming_group.add_argument(
"--preload-model-count",
type=int,
default=1,
dest="preload_model_count",
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
)
simulstreaming_group.add_argument( simulstreaming_group.add_argument(
"--nllb-backend", "--nllb-backend",
type=str, type=str,
@@ -312,10 +318,12 @@ def parse_args():
) )
args = parser.parse_args() args = parser.parse_args()
args.transcription = not args.no_transcription args.transcription = not args.no_transcription
args.vad = not args.no_vad args.vad = not args.no_vad
args.vac = not args.no_vac
delattr(args, 'no_transcription') delattr(args, 'no_transcription')
delattr(args, 'no_vad') delattr(args, 'no_vad')
delattr(args, 'no_vac')
return args
from whisperlivekit.config import WhisperLiveKitConfig
return WhisperLiveKitConfig.from_namespace(args)

View File

@@ -1,106 +0,0 @@
from whisperlivekit.timed_objects import ASRToken
from time import time
import re
MIN_SILENCE_DURATION = 4 #in seconds
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
def blank_to_silence(tokens):
full_string = ''.join([t.text for t in tokens])
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
matches = []
for pattern in patterns:
for m in pattern.finditer(full_string):
matches.append({
'start': m.start(),
'end': m.end()
})
if matches:
# cleaned = pattern.sub(' ', full_string).strip()
# print("Cleaned:", cleaned)
cumulated_len = 0
silence_token = None
cleaned_tokens = []
for token in tokens:
if matches:
start = cumulated_len
end = cumulated_len + len(token.text)
cumulated_len = end
if start >= matches[0]['start'] and end <= matches[0]['end']:
if silence_token: #previous token was already silence
silence_token.start = min(silence_token.start, token.start)
silence_token.end = max(silence_token.end, token.end)
else: #new silence
silence_token = ASRToken(
start=token.start,
end=token.end,
speaker=-2,
probability=0.95
)
else:
if silence_token: #there was silence but no more
if silence_token.duration() >= MIN_SILENCE_DURATION:
cleaned_tokens.append(
silence_token
)
silence_token = None
matches.pop(0)
cleaned_tokens.append(token)
# print(cleaned_tokens)
return cleaned_tokens
return tokens
def no_token_to_silence(tokens):
new_tokens = []
silence_token = None
for token in tokens:
if token.speaker == -2:
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
new_tokens[-1].end = token.end
else:
new_tokens.append(token)
last_end = new_tokens[-1].end if new_tokens else 0.0
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
if new_tokens and new_tokens[-1].speaker == -2:
new_tokens[-1].end = token.start
else:
silence_token = ASRToken(
start=last_end,
end=token.start,
speaker=-2,
probability=0.95
)
new_tokens.append(silence_token)
if token.speaker != -2:
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
current_time = time() - (beg_loop if beg_loop else 0.0)
last_token = tokens[-1]
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
if last_token.speaker == -2:
last_token.end = current_time
else:
tokens.append(
ASRToken(
start=tokens[-1].end,
end=current_time,
speaker=-2,
probability=0.95
)
)
return tokens
def handle_silences(tokens, beg_loop, vac_detected_silence):
if not tokens:
return []
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
return tokens

View File

@@ -1,154 +0,0 @@
import logging
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
CHECK_AROUND = 4
DEBUG = False
def is_punctuation(token):
if token.is_punctuation():
return True
return False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if is_punctuation(tokens[ind]):
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if is_punctuation(token):
break
if token.speaker != speaker:
return ind, token.speaker
return None, speaker
def new_line(
token,
):
return Line(
speaker = token.corrected_speaker,
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
start = token.start,
end = token.end,
detected_language=token.detected_language
)
def append_token_to_last_line(lines, sep, token):
if not lines:
lines.append(new_line(token))
else:
if token.text:
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
lines[-1].end = token.end
if not lines[-1].detected_language and token.detected_language:
lines[-1].detected_language = token.detected_language
def format_output(state, silence, args, sep):
diarization = args.diarization
disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
translation_buffer = state.translation_buffer
last_validated_token = state.last_validated_token
previous_speaker = 1
undiarized_text = []
tokens = handle_silences(tokens, state.beg_loop, silence)
last_punctuation = None
for i, token in enumerate(tokens[last_validated_token:]):
speaker = int(token.speaker)
token.corrected_speaker = speaker
if not diarization:
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
token.corrected_speaker = 1
token.validated_speaker = True
else:
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if token.speaker != previous_speaker:
token.validated_speaker = True
# perfect, diarization perfectly aligned
last_punctuation = None
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
elif speaker != previous_speaker:
if not (speaker == -2 or previous_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
token.corrected_speaker = previous_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = previous_speaker
token.validated_speaker = False
if token.validated_speaker:
state.last_validated_token = i
previous_speaker = token.corrected_speaker
previous_speaker = 1
lines = []
for token in tokens:
if int(token.corrected_speaker) != int(previous_speaker):
lines.append(new_line(token))
else:
append_token_to_last_line(lines, sep, token)
previous_speaker = token.corrected_speaker
if lines:
unassigned_translated_segments = []
for ts in translation_validated_segments:
assigned = False
for line in lines:
if ts and ts.overlaps_with(line):
if ts.is_within(line):
line.translation += ts.text + ' '
assigned = True
break
else:
ts0, ts1 = ts.approximate_cut_at(line.end)
if ts0 and line.overlaps_with(ts0):
line.translation += ts0.text + ' '
if ts1:
unassigned_translated_segments.append(ts1)
assigned = True
break
if not assigned:
unassigned_translated_segments.append(ts)
if unassigned_translated_segments:
for line in lines:
remaining_segments = []
for ts in unassigned_translated_segments:
if ts and ts.overlaps_with(line):
line.translation += ts.text + ' '
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
if state.buffer_transcription and lines:
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
return lines, undiarized_text

View File

@@ -1,12 +1,22 @@
import torch
import numpy as np
import warnings import warnings
from pathlib import Path from pathlib import Path
import numpy as np
import torch
""" """
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
""" """
def is_onnx_available() -> bool:
"""Check if onnxruntime is installed."""
try:
import onnxruntime
return True
except ImportError:
return False
def init_jit_model(model_path: str, device=torch.device('cpu')): def init_jit_model(model_path: str, device=torch.device('cpu')):
"""Load a JIT model from file.""" """Load a JIT model from file."""
model = torch.jit.load(model_path, map_location=device) model = torch.jit.load(model_path, map_location=device)
@@ -14,12 +24,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
return model return model
class OnnxWrapper(): class OnnxSession():
"""ONNX Runtime wrapper for Silero VAD model.""" """
Shared ONNX session for Silero VAD model (stateless).
"""
def __init__(self, path, force_onnx_cpu=False): def __init__(self, path, force_onnx_cpu=False):
global np
import numpy as np
import onnxruntime import onnxruntime
opts = onnxruntime.SessionOptions() opts = onnxruntime.SessionOptions()
@@ -31,13 +41,28 @@ class OnnxWrapper():
else: else:
self.session = onnxruntime.InferenceSession(path, sess_options=opts) self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states() self.path = path
if '16k' in path: if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!') warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000] self.sample_rates = [16000]
else: else:
self.sample_rates = [8000, 16000] self.sample_rates = [8000, 16000]
class OnnxWrapper():
"""
ONNX Runtime wrapper for Silero VAD model with per-instance state.
"""
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
self._shared_session = session
self.sample_rates = session.sample_rates
self.reset_states()
@property
def session(self):
return self._shared_session.session
def _validate_input(self, x, sr: int): def _validate_input(self, x, sr: int):
if x.dim() == 1: if x.dim() == 1:
x = x.unsqueeze(0) x = x.unsqueeze(0)
@@ -90,7 +115,7 @@ class OnnxWrapper():
out, state = ort_outs out, state = ort_outs
self._state = torch.from_numpy(state) self._state = torch.from_numpy(state)
else: else:
raise ValueError() raise ValueError(f"Unsupported sampling rate {sr}. Supported: {self.sample_rates} (with sample sizes 256 for 8000, 512 for 16000)")
self._context = x[..., -context_size:] self._context = x[..., -context_size:]
self._last_sr = sr self._last_sr = sr
@@ -100,59 +125,63 @@ class OnnxWrapper():
return out return out
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16): def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
""" """Get the path to the ONNX model file."""
Load Silero VAD model (JIT or ONNX).
Parameters
----------
model_path : str, optional
Path to model file. If None, uses default bundled model.
onnx : bool, default False
Whether to use ONNX runtime (requires onnxruntime package).
opset_version : int, default 16
ONNX opset version (15 or 16). Only used if onnx=True.
Returns
-------
model
Loaded VAD model (JIT or ONNX wrapper)
"""
available_ops = [15, 16] available_ops = [15, 16]
if onnx and opset_version not in available_ops: if opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}') raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}')
if model_path is None: if model_path is None:
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
data_dir = current_dir / 'vad_models' data_dir = current_dir / 'silero_vad_models'
if onnx: if opset_version == 16:
if opset_version == 16: model_name = 'silero_vad.onnx'
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
else: else:
model_name = 'silero_vad.jit' model_name = f'silero_vad_16k_op{opset_version}.onnx'
model_path = data_dir / model_name model_path = data_dir / model_name
if not model_path.exists(): if not model_path.exists():
raise FileNotFoundError( raise FileNotFoundError(
f"Model file not found: {model_path}\n" f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files." f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
) )
else: else:
model_path = Path(model_path) model_path = Path(model_path)
if onnx:
try: return model_path
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
except ImportError:
raise ImportError( def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
"ONNX runtime not available. Install with: pip install onnxruntime\n" """
"Or use JIT model by setting onnx=False" Load a shared ONNX session for Silero VAD.
"""
path = _get_onnx_model_path(model_path, opset_version)
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
def load_jit_vad(model_path: str = None):
"""
Load Silero VAD model in JIT format.
"""
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models'
model_name = 'silero_vad.jit'
model_path = data_dir / model_name
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
) )
else: else:
model = init_jit_model(str(model_path)) model_path = Path(model_path)
model = init_jit_model(str(model_path))
return model return model
@@ -226,8 +255,8 @@ class VADIterator:
if not torch.is_tensor(x): if not torch.is_tensor(x):
try: try:
x = torch.Tensor(x) x = torch.Tensor(x)
except: except (ValueError, TypeError, RuntimeError) as exc:
raise TypeError("Audio cannot be casted to tensor. Cast it manually") raise TypeError("Audio cannot be cast to tensor. Cast it manually") from exc
window_size_samples = len(x[0]) if x.dim() == 2 else len(x) window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
self.current_sample += window_size_samples self.current_sample += window_size_samples
@@ -276,19 +305,22 @@ class FixedVADIterator(VADIterator):
elif r is not None: elif r is not None:
if "end" in r: if "end" in r:
ret["end"] = r["end"] ret["end"] = r["end"]
if "start" in r and "end" in ret: if "start" in r:
del ret["end"] ret["start"] = r["start"]
if "end" in ret:
del ret["end"]
return ret if ret != {} else None return ret if ret != {} else None
if __name__ == "__main__": if __name__ == "__main__":
model = load_silero_vad(onnx=False) # vad = FixedVADIterator(load_jit_vad())
vad = FixedVADIterator(model) vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
audio_buffer = np.array([0] * 512, dtype=np.float32) audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer) result = vad(audio_buffer)
print(f" 512 samples: {result}") print(f" 512 samples: {result}")
# test with 511 samples # test with 511 samples
audio_buffer = np.array([0] * 511, dtype=np.float32) audio_buffer = np.array([0] * 511, dtype=np.float32)
result = vad(audio_buffer) result = vad(audio_buffer)
print(f" 511 samples: {result}")

View File

@@ -0,0 +1,552 @@
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from .config import AlignAttConfig
DEC_PAD = 50257
logger = logging.getLogger(__name__)
class AlignAttBase(ABC):
"""
Abstract base class for AlignAtt streaming decoders.
Provides shared logic for both PyTorch and MLX implementations:
- Properties (speaker, global_time_offset)
- Pure-Python methods (warmup, trim_context, refresh_segment, etc.)
- Template infer() with abstract hooks for tensor-specific operations
- Post-decode logic (token splitting, timestamped word building)
Subclasses must implement ~20 abstract methods for tensor-specific ops.
"""
# === Properties ===
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
# === Constructor helpers ===
def _base_init(self, cfg: AlignAttConfig, model):
"""Common initialization — call from subclass __init__."""
self.model = model
self.cfg = cfg
self.decode_options = DecodingOptions(
language=cfg.language,
without_timestamps=True,
task=cfg.task,
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.max_text_len = model.dims.n_text_ctx
self.num_decoder_layers = len(model.decoder.blocks)
if cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = cfg.max_context_tokens
def _init_state_common(self, cfg: AlignAttConfig):
"""Common state initialization — call from subclass _init_state."""
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
# === Shared concrete methods ===
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task,
)
self.state.tokenizer = self.tokenizer
def trim_context(self):
logger.info("Trimming context")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
after = 0 if self.cfg.static_init_prompt is None else len(self.cfg.static_init_prompt)
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
def refresh_segment(self, complete=False):
logger.debug("Refreshing segment:")
self.init_tokens()
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
def segments_len(self):
return sum(s.shape[0] for s in self.state.segments) / 16000
def _apply_minseglen(self):
segments_len = self.segments_len()
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def _clean_cache(self):
self.state.clean_cache()
def debug_print_tokens(self, tokens):
for i in range(min(self.cfg.beam_size, tokens.shape[0])):
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
# === Language detection ===
def _detect_language_if_needed(self, encoder_feature):
if (
self.cfg.language == "auto"
and self.state.detected_language is None
and self.state.first_timestamp
):
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}")
# === Template infer() ===
def infer(self, is_last=False):
"""Main inference — template method calling abstract hooks for tensor ops."""
new_segment = True
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
return []
input_segments = self._concat_segments()
encoder_feature, content_mel_len = self._encode(input_segments)
self._evaluate(encoder_feature)
self._detect_language_if_needed(encoder_feature)
self.trim_context()
current_tokens = self._current_tokens()
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = self._init_sum_logprobs()
completed = False
token_len_before = current_tokens.shape[1]
l_absolute_timestamps = []
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced = 0
most_attended_frame = None
while not completed and current_tokens.shape[1] < self.max_text_len:
tokens_produced += 1
if tokens_produced > max_tokens:
logger.warning(
f"[Loop Detection] Too many tokens ({tokens_produced}) "
f"for {audio_duration_s:.2f}s audio. Breaking."
)
current_tokens = current_tokens[:, :token_len_before]
break
tokens_for_logits = current_tokens if new_segment else current_tokens[:, -1:]
logits, cross_attns = self._get_logits_and_cross_attn(
tokens_for_logits, encoder_feature
)
self._evaluate(logits)
accumulated_cross_attns.append(cross_attns)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self._check_no_speech(logits):
break
logits = logits[:, -1, :]
if new_segment:
logits = self._suppress_blank_tokens(logits)
new_segment = False
logits = self._apply_token_suppression(logits)
logits = self._apply_dry_penalty(logits, current_tokens)
current_tokens, completed = self._update_tokens(
current_tokens, logits, sum_logprobs
)
self._evaluate(current_tokens)
logger.debug(f"Decoding completed: {completed}")
self.debug_print_tokens(current_tokens)
attn = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
frames_list, most_attended_frame = self._get_attended_frames(attn)
absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in frames_list
]
l_absolute_timestamps.append(absolute_timestamps[0])
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
if completed:
current_tokens = current_tokens[:, :-1]
break
# Rewind check
if (
not is_last
and self.state.last_attend_frame - most_attended_frame
> self.cfg.rewind_threshold
):
if current_tokens.shape[1] > 1 and self._is_special_token(current_tokens):
logger.debug("omit rewinding from special tokens")
self.state.last_attend_frame = most_attended_frame
else:
logger.debug(
f"[rewind detected] current: {most_attended_frame}, "
f"last: {self.state.last_attend_frame}"
)
self.state.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = self._rewind_tokens()
break
else:
self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (
4 if is_last else self.cfg.frame_threshold
):
logger.debug(
f"attention reaches the end: {most_attended_frame}/{content_mel_len}"
)
current_tokens = current_tokens[:, :-1]
break
# Post-decode: split tokens and build timestamped words
tokens_to_split = self._tokens_to_list(current_tokens, token_len_before)
if self.state.pending_incomplete_tokens:
logger.debug(
f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} "
f"pending tokens: {self.state.pending_incomplete_tokens}"
)
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
new_hypothesis, split_words, split_tokens = self._split_tokens(
tokens_to_split, fire_detected, is_last
)
new_tokens_tensor = self._make_new_tokens_tensor(new_hypothesis)
self.state.tokens.append(new_tokens_tensor)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = self._build_timestamped_words(
split_words, split_tokens, l_absolute_timestamps
)
self._handle_pending_tokens(split_words, split_tokens)
return timestamped_words
# === Post-decode shared helpers ===
def _split_tokens(self, tokens_list, fire_detected, is_last):
"""Split token list into words. Returns (hypothesis, split_words, split_tokens)."""
if fire_detected or is_last:
new_hypothesis = tokens_list
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_list)
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
return new_hypothesis, split_words, split_tokens
def _build_timestamped_words(self, split_words, split_tokens, l_absolute_timestamps):
"""Build list of timestamped ASRToken from split words."""
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
if replacement_char in word:
cleaned = word.replace(replacement_char, "")
if not cleaned.strip():
logger.debug(f"[UTF-8 Filter] Skipping: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
logger.debug(f"[UTF-8 Filter] Cleaned {repr(word)} -> {repr(cleaned)}")
word = cleaned
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except IndexError:
logger.warning(
f"Timestamp index {timestamp_idx} out of range, using last timestamp"
)
current_timestamp = (
l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
)
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text=word,
speaker=self.state.speaker,
detected_language=self.state.detected_language,
).with_offset(self.state.global_time_offset)
timestamped_words.append(timestamp_entry)
return timestamped_words
def _handle_pending_tokens(self, split_words, split_tokens):
"""Handle incomplete UTF-8 tokens for next chunk."""
MAX_PENDING_TOKENS = 10
MAX_PENDING_RETRIES = 2
replacement_char = "\ufffd"
if split_words and replacement_char in split_words[-1]:
self.state.pending_retries += 1
if self.state.pending_retries > MAX_PENDING_RETRIES:
logger.warning(
f"[UTF-8 Fix] Dropping {len(split_tokens[-1])} incomplete tokens "
f"after {MAX_PENDING_RETRIES} retries (won't resolve)"
)
self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
elif len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug(
f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} "
f"incomplete tokens for next chunk (retry {self.state.pending_retries})"
)
else:
logger.warning(
f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens "
f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)"
)
self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
else:
self.state.pending_incomplete_tokens = []
self.state.pending_retries = 0
# === Repetition penalty ===
def _apply_dry_penalty(self, logits, current_tokens):
"""DRY penalty v0: penalize tokens that would extend a verbatim repetition.
See https://github.com/oobabooga/text-generation-webui/pull/5677
Scans the decoded sequence for positions where the current suffix already
appeared --> for each such match, the token that followed it in the past is
penalised exponentially with the match length
"""
eot = self.tokenizer.eot
seq = current_tokens[0].tolist()
if len(seq) < 5:
return logits
last = seq[-1]
if last >= eot:
return logits
penalties = {}
for i in range(len(seq) - 2, -1, -1):
if seq[i] != last:
continue
next_tok = seq[i + 1]
if next_tok >= eot:
continue
length = 1
while length < 50:
j, k = i - length, len(seq) - 1 - length
if j < 0 or k <= i:
break
if seq[j] != seq[k] or seq[j] >= eot:
break
length += 1
if next_tok not in penalties or length > penalties[next_tok]:
penalties[next_tok] = length
if penalties:
max_len = max(penalties.values())
if max_len >= 4:
logger.debug(f"[DRY] penalising {len(penalties)} tokens (longest match: {max_len})")
for tok, length in penalties.items():
if length >= 2:
logits[:, tok] = logits[:, tok] - 1.0 * 2.0 ** (length - 2)
return logits
# === Abstract methods — subclass must implement ===
@abstractmethod
def _init_state(self, cfg: AlignAttConfig):
"""Initialize per-session decoder state."""
...
@abstractmethod
def init_tokens(self):
"""Initialize token sequence with framework-specific tensors."""
...
@abstractmethod
def init_context(self):
"""Initialize context buffer with framework-specific TokenBuffer."""
...
@abstractmethod
def insert_audio(self, segment=None):
"""Insert audio segment into buffer."""
...
@abstractmethod
def _current_tokens(self):
"""Build current token tensor for decoding."""
...
@abstractmethod
def fire_at_boundary(self, feature):
"""Check if we should fire at word boundary."""
...
@abstractmethod
def lang_id(self, encoder_features):
"""Language detection from encoder features. Returns (tokens, probs)."""
...
@abstractmethod
def _concat_segments(self):
"""Concatenate audio segments into single array/tensor."""
...
@abstractmethod
def _encode(self, input_segments):
"""Encode audio. Returns (encoder_feature, content_mel_len)."""
...
@abstractmethod
def _init_sum_logprobs(self):
"""Create zero sum_logprobs tensor for beam search."""
...
@abstractmethod
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
"""Get logits and cross-attention from decoder. Returns (logits, cross_attns)."""
...
@abstractmethod
def _check_no_speech(self, logits):
"""Check no_speech probability at start of segment. Returns True to break."""
...
@abstractmethod
def _suppress_blank_tokens(self, logits):
"""Suppress blank/EOT tokens at segment start. Returns modified logits."""
...
@abstractmethod
def _apply_token_suppression(self, logits):
"""Apply general token suppression. Returns modified logits."""
...
@abstractmethod
def _update_tokens(self, current_tokens, logits, sum_logprobs):
"""Update tokens via decoder. Returns (current_tokens, completed)."""
...
@abstractmethod
def _process_cross_attention(self, accumulated_cross_attns, content_mel_len):
"""Process cross-attention for alignment. Returns attention tensor."""
...
@abstractmethod
def _get_attended_frames(self, attn):
"""Get most attended frames. Returns (frames_as_python_list, first_frame_int)."""
...
@abstractmethod
def _is_special_token(self, current_tokens):
"""Check if second-to-last token is a special token (>= DEC_PAD)."""
...
@abstractmethod
def _rewind_tokens(self):
"""Concatenate state tokens for rewind. Returns token tensor."""
...
@abstractmethod
def _tokens_to_list(self, current_tokens, start_col):
"""Extract tokens as Python list from start_col onwards."""
...
@abstractmethod
def _make_new_tokens_tensor(self, hypothesis):
"""Create tensor from hypothesis token list, repeated for beam search."""
...
@abstractmethod
def _evaluate(self, tensor):
"""Evaluate lazy tensor (mx.eval for MLX, no-op for PyTorch)."""
...

View File

@@ -1,115 +1,104 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
from whisperlivekit.warmup import load_file
from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
import gc import gc
import logging
import os
import platform
import sys
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
from whisperlivekit.warmup import load_file
from whisperlivekit.whisper import load_model, tokenizer
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
try: HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
from .mlx import MLXAlignAtt
else: else:
try: mlx_model_mapping = {}
from faster_whisper import WhisperModel MLXAlignAtt = None
HAS_FASTER_WHISPER = True HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
except ImportError: if HAS_FASTER_WHISPER:
HAS_FASTER_WHISPER = False from faster_whisper import WhisperModel
else:
def model_path_and_type(model_path): WhisperModel = None
path = Path(model_path)
compatible_whisper_mlx = False
compatible_faster_whisper = False
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
if path.is_dir():
for file in path.iterdir():
if file.is_file():
if file.name in ['weights.npz', "weights.safetensors"]:
compatible_whisper_mlx = True
elif file.suffix.lower() == '.bin':
compatible_faster_whisper = True
elif file.suffix.lower() == '.pt':
pt_path = file
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
MIN_DURATION_REAL_SILENCE = 5
class SimulStreamingOnlineProcessor: class SimulStreamingOnlineProcessor:
"""Online processor for SimulStreaming ASR."""
SAMPLING_RATE = 16000 SAMPLING_RATE = 16000
def __init__( def __init__(self, asr, logfile=sys.stderr):
self,
asr,
logfile=sys.stderr,
):
self.asr = asr self.asr = asr
self.logfile = logfile self.logfile = logfile
self.end = 0.0 self.end = 0.0
self.buffer = [] self.buffer = []
self.committed: List[ASRToken] = [] self.model = self._create_alignatt()
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
#can be moved
if asr.tokenizer: if asr.tokenizer:
self.model.tokenizer = asr.tokenizer self.model.tokenizer = asr.tokenizer
self.model.state.tokenizer = asr.tokenizer
def load_new_backend(self): def _create_alignatt(self):
model = self.asr.get_new_model_instance() """Create the AlignAtt decoder instance based on ASR mode."""
self.model = PaddedAlignAttWhisper( if self.asr.use_full_mlx and HAS_MLX_WHISPER:
cfg=self.asr.cfg, return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
loaded_model=model, else:
mlx_encoder=self.asr.mlx_encoder, return AlignAtt(
fw_encoder=self.asr.fw_encoder, cfg=self.asr.cfg,
loaded_model=self.asr.shared_model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
) )
def insert_silence(self, silence_duration, offset): def start_silence(self):
""" tokens, processed_upto = self.process_iter(is_last=True)
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame return tokens, processed_upto
"""
if silence_duration < 5: def end_silence(self, silence_duration, offset):
gap_silence = torch.zeros(int(16000*silence_duration)) """Handle silence period."""
self.model.insert_audio(gap_silence) self.end += silence_duration
# self.global_time_offset += silence_duration long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
else: if not long_silence:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer. gap_len = int(16000 * silence_duration)
if gap_len > 0:
if self.asr.use_full_mlx:
gap_silence = np.zeros(gap_len, dtype=np.float32)
else:
gap_silence = torch.zeros(gap_len)
self.model.insert_audio(gap_silence)
if long_silence:
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
self.model.global_time_offset = silence_duration + offset self.model.global_time_offset = silence_duration + offset
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time): def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming.""" """Append an audio chunk to be processed by SimulStreaming."""
self.end = audio_stream_end_time
# Convert numpy array to torch tensor if self.asr.use_full_mlx:
audio_tensor = torch.from_numpy(audio).float() self.model.insert_audio(audio)
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. else:
self.model.insert_audio(audio_tensor) audio_tensor = torch.from_numpy(audio).float()
self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker): def new_speaker(self, change_speaker: ChangeSpeaker):
self.process_iter(is_last=True) """Handle speaker change event."""
self.model.refresh_segment(complete=True) self.process_iter(is_last=True)
self.model.speaker = change_speaker.speaker self.model.refresh_segment(complete=True)
self.global_time_offset = change_speaker.start self.model.speaker = change_speaker.speaker
self.model.global_time_offset = change_speaker.start
def get_buffer(self): def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
@@ -123,15 +112,16 @@ class SimulStreamingOnlineProcessor:
""" """
try: try:
timestamped_words = self.model.infer(is_last=is_last) timestamped_words = self.model.infer(is_last=is_last)
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
if not timestamped_words:
return [], self.end
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
self.buffer.extend(timestamped_words) self.buffer.extend(timestamped_words)
return [], self.end return [], self.end
self.committed.extend(timestamped_words)
self.buffer = [] self.buffer = []
return timestamped_words, self.end return timestamped_words, self.end
except Exception as e: except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}") logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end return [], self.end
@@ -139,6 +129,10 @@ class SimulStreamingOnlineProcessor:
def warmup(self, audio, init_prompt=""): def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model.""" """Warmup the SimulStreaming model."""
try: try:
if self.asr.use_full_mlx:
# MLX mode: ensure numpy array
if hasattr(audio, 'numpy'):
audio = audio.numpy()
self.model.insert_audio(audio) self.model.insert_audio(audio)
self.model.infer(True) self.model.infer(True)
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
@@ -147,14 +141,15 @@ class SimulStreamingOnlineProcessor:
logger.exception(f"SimulStreaming warmup failed: {e}") logger.exception(f"SimulStreaming warmup failed: {e}")
def __del__(self): def __del__(self):
# free the model and add a new model to stack.
# del self.model
gc.collect() gc.collect()
torch.cuda.empty_cache() if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
# self.asr.new_model_to_stack() try:
self.model.remove_hooks() torch.cuda.empty_cache()
except Exception:
pass
class SimulStreamingASR():
class SimulStreamingASR:
"""SimulStreaming backend with AlignAtt policy.""" """SimulStreaming backend with AlignAtt policy."""
sep = "" sep = ""
@@ -169,32 +164,51 @@ class SimulStreamingASR():
self.decoder_type = 'greedy' if self.beams == 1 else 'beam' self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.fast_encoder = False self.fast_encoder = False
self._resolved_model_path = None
self.encoder_backend = "whisper"
self.use_full_mlx = getattr(self, "use_full_mlx", False)
preferred_backend = getattr(self, "backend", "auto")
compatible_whisper_mlx, compatible_faster_whisper = True, True
pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
if self.model_path: if self.model_path:
pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path) resolved_model_path = resolve_model_path(self.model_path)
self._resolved_model_path = resolved_model_path
self.model_path = str(resolved_model_path)
model_info = detect_model_format(resolved_model_path)
compatible_whisper_mlx = model_info.compatible_whisper_mlx
compatible_faster_whisper = model_info.compatible_faster_whisper
if not self.use_full_mlx and not model_info.has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
)
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
elif self.model_size is not None: elif self.model_size is not None:
model_mapping = { self.model_name = self.model_size
'tiny': './tiny.pt', else:
'base': './base.pt', raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
'small': './small.pt',
'medium': './medium.pt', is_multilingual = not self.model_name.endswith(".en")
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt', self.encoder_backend = self._resolve_encoder_backend(
'base.en': './base.en.pt', preferred_backend,
'small.en': './small.en.pt', compatible_whisper_mlx,
'tiny.en': './tiny.en.pt', compatible_faster_whisper,
'large-v2': './large-v2.pt', )
'large-v3': './large-v3.pt', self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
'large': './large-v3.pt' if self.encoder_backend == "whisper":
} self.disable_fast_encoder = True
pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
self.model_name = pt_path.name.replace(".pt", "")
# MLX full decoder disabled by default — MLXAlignAtt has known issues
# with token generation after punctuation. Users can opt-in with
# --use-full-mlx if they want to test it.
# if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
# if not hasattr(self, '_full_mlx_disabled'):
# self.use_full_mlx = True
self.cfg = AlignAttConfig( self.cfg = AlignAttConfig(
tokenizer_is_multilingual= not self.model_name.endswith(".en"), tokenizer_is_multilingual= is_multilingual,
segment_length=self.min_chunk_size, segment_length=self.min_chunk_size,
frame_threshold=self.frame_threshold, frame_threshold=self.frame_threshold,
language=self.lan, language=self.lan,
@@ -203,7 +217,7 @@ class SimulStreamingASR():
cif_ckpt_path=self.cif_ckpt_path, cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam", decoder_type="beam",
beam_size=self.beams, beam_size=self.beams,
task=self.task, task="translate" if self.direct_english_translation else "transcribe",
never_fire=self.never_fire, never_fire=self.never_fire,
init_prompt=self.init_prompt, init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens, max_context_tokens=self.max_context_tokens,
@@ -211,84 +225,135 @@ class SimulStreamingASR():
) )
# Set up tokenizer for translation if needed # Set up tokenizer for translation if needed
if self.task == "translate": if self.direct_english_translation:
self.tokenizer = self.set_translate_task() self.tokenizer = self.set_translate_task()
else: else:
self.tokenizer = None self.tokenizer = None
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
self.shared_model = None
if self.use_full_mlx and HAS_MLX_WHISPER:
logger.info('MLX Whisper backend used.')
if self._resolved_model_path is not None:
self.mlx_encoder, self.fw_encoder = None, None mlx_model_path = str(self._resolved_model_path)
if not self.disable_fast_encoder: else:
if HAS_MLX_WHISPER: mlx_model_path = mlx_model_mapping.get(self.model_name)
print('Simulstreaming will use MLX whisper to increase encoding speed.') if not mlx_model_path:
if self.model_path and compatible_whisper_mlx: raise FileNotFoundError(
mlx_model = self.model_path f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
else:
mlx_model = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.fast_encoder = True
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
print('Simulstreaming will use Faster Whisper for the encoder.')
if self.model_path and compatible_faster_whisper:
fw_model = self.model_path
else:
fw_model = self.model_name
self.fw_encoder = WhisperModel(
fw_model,
device='auto',
compute_type='auto',
) )
self.fast_encoder = True self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
self._warmup_mlx_model()
elif self.encoder_backend == "mlx-whisper":
# hybrid mode: mlx encoder + pytorch decoder
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
if self._resolved_model_path is not None:
mlx_model_path = str(self._resolved_model_path)
else:
mlx_model_path = mlx_model_mapping.get(self.model_name)
if not mlx_model_path:
raise FileNotFoundError(
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
)
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
self.shared_model = self.load_model()
elif self.encoder_backend == "faster-whisper":
print('SimulStreaming will use Faster Whisper for the encoder.')
if self._resolved_model_path is not None:
fw_model = str(self._resolved_model_path)
else:
fw_model = self.model_name
self.fw_encoder = WhisperModel(
fw_model,
device='auto',
compute_type='auto',
)
self.shared_model = self.load_model()
else:
self.shared_model = self.load_model()
def _warmup_mlx_model(self):
"""Warmup the full MLX model."""
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
temp_model = MLXAlignAtt(
cfg=self.cfg,
mlx_model=self.mlx_model,
)
temp_model.warmup(warmup_audio)
logger.info("Full MLX model warmed up successfully")
self.models = [self.load_model() for i in range(self.preload_model_count)]
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
choice = preferred_backend or "auto"
if self.disable_fast_encoder:
return "whisper"
if choice == "whisper":
return "whisper"
if choice == "mlx-whisper":
if not self._can_use_mlx(compatible_whisper_mlx):
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
return "mlx-whisper"
if choice == "faster-whisper":
if not self._can_use_faster(compatible_faster_whisper):
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
return "faster-whisper"
if choice == "openai-api":
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
# auto mode
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
return "mlx-whisper"
if self._can_use_faster(compatible_faster_whisper):
return "faster-whisper"
return "whisper"
def _has_custom_model_path(self):
return self._resolved_model_path is not None
def _can_use_mlx(self, compatible_whisper_mlx):
if not HAS_MLX_WHISPER:
return False
if self._has_custom_model_path():
return compatible_whisper_mlx
return self.model_name in mlx_model_mapping
def _can_use_faster(self, compatible_faster_whisper):
if not HAS_FASTER_WHISPER:
return False
if self._has_custom_model_path():
return compatible_faster_whisper
return True
def load_model(self): def load_model(self):
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model( whisper_model = load_model(
name=self.model_path if self.model_path else self.model_name, name=model_ref,
download_root=self.model_path, download_root=getattr(self, 'model_cache_dir', None),
decoder_only=self.fast_encoder, decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads custom_alignment_heads=self.custom_alignment_heads,
) lora_path=lora_path,
)
warmup_audio = load_file(self.warmup_file) warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None: if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float() warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder: if self.fast_encoder:
temp_model = PaddedAlignAttWhisper( temp_model = AlignAtt(
cfg=self.cfg, cfg=self.cfg,
loaded_model=whisper_model, loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder, mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder, fw_encoder=self.fw_encoder,
) )
temp_model.warmup(warmup_audio) temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else: else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None) whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
return whisper_model return whisper_model
def get_new_model_instance(self):
"""
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
"""
if len(self.models) == 0:
self.models.append(self.load_model())
new_model = self.models.pop()
return new_model
# self.models[0]
def new_model_to_stack(self):
self.models.append(self.load_model())
def set_translate_task(self): def set_translate_task(self):
"""Set up translation task.""" """Set up translation task."""
if self.cfg.language == 'auto': if self.cfg.language == 'auto':
raise Exception('Translation cannot be done with language = auto') raise ValueError('Translation cannot be done with language = auto')
return tokenizer.get_tokenizer( return tokenizer.get_tokenizer(
multilingual=True, multilingual=True,
language=self.cfg.language, language=self.cfg.language,

View File

@@ -1,17 +1,32 @@
from .whisper.decoding import PyTorchInference from torch import Tensor
from whisperlivekit.whisper.decoding import PyTorchInference
# extention of PyTorchInference for beam search
class BeamPyTorchInference(PyTorchInference): class BeamPyTorchInference(PyTorchInference):
"""Extension of PyTorchInference for beam search with cross-attention support."""
def _kv_modules(self): def _kv_cache_ids(self):
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks] """Get cache_id strings for self-attention key/value modules."""
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks] key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks]
return key_modules + value_modules value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks]
return key_ids + value_ids
def rearrange_kv_cache(self, source_indices): def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))): if source_indices != list(range(len(source_indices))):
for module_cache_id in self._kv_modules(): for cache_id in self._kv_cache_ids():
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach() if cache_id in self.kv_cache:
from torch import Tensor self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) def logits(
self,
tokens: Tensor,
audio_features: Tensor,
return_cross_attn: bool = False,
):
"""Get logits, optionally returning cross-attention weights."""
return self.model.decoder(
tokens, audio_features,
kv_cache=self.kv_cache,
return_cross_attn=return_cross_attn,
)

View File

@@ -1,8 +1,7 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal from typing import Literal
@dataclass @dataclass
class AlignAttConfig(): class AlignAttConfig():
eval_data_path: str = "tmp" eval_data_path: str = "tmp"

View File

@@ -0,0 +1,97 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import torch
@dataclass
class DecoderState:
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[torch.Tensor] = field(default_factory=list)
initial_tokens: Optional[torch.Tensor] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[torch.Tensor] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
pending_retries: int = 0
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
CIFLinear: Optional[torch.nn.Module] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens_fn: Any = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
"""Clean the kv_cache after each inference step."""
# Explicitly delete tensor references to free GPU memory
if self.kv_cache:
for key in list(self.kv_cache.keys()):
tensor = self.kv_cache.pop(key, None)
if tensor is not None:
del tensor
# Clear the dict
self.kv_cache.clear()
# Force GPU cache cleanup (only if CUDA is available)
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if self.decoder_type == "beam" and self.inference is not None:
# Create NEW dict instead of sharing reference
self.inference.kv_cache = {}
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
"""
Reset transient state for a new segment.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.pending_retries = 0
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = {}
self.first_timestamp = None

View File

@@ -1,43 +0,0 @@
class Tokens:
def __init__(self, tokens):
self.tokens = tokens
# def clone(self):
# return Tokens(self.tokens.clone())
def __str__(self):
return str(self.tokens.tolist())
def __repr__(self):
return self.__str__()
class BeamTokens(Tokens):
def __init__(self, tokens, beam_size):
self.tokens = tokens
self.beam_size = beam_size
def clone(self):
return BeamTokens(self.tokens.clone())
def __str__(self):
return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"
def __repr__(self):
return self.__str__()
def as_text(self, tokenizer):
return tokenizer.decode(self.tokens)
class Logits(Tokens):
def __init__(self, logits):
super().__init__(logits)
# def clone(self):
# return Logits(self.tokens.clone(), self.beam_size)
def __str__(self):
# return "abc"
return f"Logits({self.tokens.shape})"
def __repr__(self):
return self.__str__()

View File

@@ -0,0 +1,11 @@
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
from .simul_whisper import MLXAlignAtt
__all__ = [
"MLXAlignAtt",
"MLXBeamSearchDecoder",
"MLXDecoderState",
"MLXGreedyDecoder",
"MLXInference",
]

View File

@@ -0,0 +1,78 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
@dataclass
class MLXDecoderState:
"""
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
where each element is a tuple of mx.arrays.
"""
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[mx.array] = field(default_factory=list)
initial_tokens: Optional[mx.array] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[np.ndarray] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
pending_retries: int = 0
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
cif_weights: Optional[mx.array] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens: Optional[Tuple[int, ...]] = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
self.kv_cache = None
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = None
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.pending_retries = 0
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = None
self.first_timestamp = None

View File

@@ -0,0 +1,219 @@
"""
MLX-native token decoders for streaming ASR.
"""
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
class MLXGreedyDecoder:
"""Greedy decoder using MLX operations."""
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens with next predicted token.
Args:
tokens: Current token sequence, shape (batch, seq_len)
logits: Logits for next token, shape (batch, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch,)
Returns:
Updated tokens and completion flag
"""
if self.temperature == 0:
next_tokens = mx.argmax(logits, axis=-1)
else:
probs = mx.softmax(logits / self.temperature, axis=-1)
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
batch_size = logprobs.shape[0]
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
sum_logprobs = sum_logprobs + current_logprobs * mask
eot_mask = (tokens[:, -1] == self.eot)
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
completed = bool(mx.all(tokens[:, -1] == self.eot))
return tokens, completed
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
"""Finalize decoding by ensuring EOT at end."""
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
tokens = mx.concatenate([tokens, eot_column], axis=1)
return tokens, sum_logprobs.tolist()
class MLXBeamSearchDecoder:
"""Beam search decoder using MLX operations."""
def __init__(
self,
beam_size: int,
eot: int,
inference: Any,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences: Optional[List[Dict]] = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
"""Reset finished sequences for new segment."""
self.finished_sequences = None
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens using beam search.
Args:
tokens: Current token sequences, shape (batch * beam_size, seq_len)
logits: Logits for next token, shape (batch * beam_size, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
Returns:
Updated tokens and completion flag
"""
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
logprobs_np = np.array(logprobs)
tokens_np = np.array(tokens)
sum_logprobs_np = np.array(sum_logprobs)
next_tokens, source_indices, finished_sequences = [], [], []
new_sum_logprobs = []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens_np[idx].tolist()
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
for token_idx in top_k_indices:
logprob = logprobs_np[idx, token_idx]
new_logprob = sum_logprobs_np[idx] + logprob
sequence = tuple(prefix + [int(token_idx)])
scores[sequence] = new_logprob
sources[sequence] = idx
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
new_sum_logprobs.append(scores[sequence])
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
self.inference.rearrange_kv_cache(source_indices)
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break
previously_finished[seq] = newly_finished[seq]
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
"""Finalize beam search by selecting best sequences."""
preceding_tokens_np = np.array(preceding_tokens)
sum_logprobs_np = np.array(sum_logprobs)
n_audio = preceding_tokens_np.shape[0] // self.beam_size
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
sum_logprobs_list: List[float] = [0.0] * n_audio
for i, sequences in enumerate(self.finished_sequences):
if sequences:
best_seq = max(sequences, key=sequences.get)
tokens_list[i] = list(best_seq)
sum_logprobs_list[i] = sequences[best_seq]
else:
idx = i * self.beam_size
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
max_len = max(len(t) for t in tokens_list)
for i, t in enumerate(tokens_list):
tokens_list[i] = t + [self.eot] * (max_len - len(t))
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
return tokens, sum_logprobs_list
class MLXInference:
"""MLX inference wrapper for beam search KV cache management."""
def __init__(self, model, initial_token_length: int):
self.model = model
self.initial_token_length = initial_token_length
self.kv_cache = None
def rearrange_kv_cache(self, source_indices: List[int]):
"""Rearrange KV cache based on beam search source indices."""
if self.kv_cache is None:
return
if source_indices == list(range(len(source_indices))):
return
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
new_cache = []
for layer_cache in self.kv_cache:
(k, v), (cross_k, cross_v) = layer_cache
new_k = k[source_indices_mx]
new_v = v[source_indices_mx]
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
self.kv_cache = new_cache
def logits(
self,
tokens: mx.array,
audio_features: mx.array,
) -> Tuple[mx.array, List]:
"""Get logits from decoder with KV cache."""
logits, self.kv_cache, cross_qk = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits, cross_qk

View File

@@ -0,0 +1,421 @@
"""MLX whisper AlignAtt streaming decoder."""
import logging
from typing import Any, List, Tuple
import mlx.core as mx
import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
from ..align_att_base import DEC_PAD, AlignAttBase
from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
logger = logging.getLogger(__name__)
class MLXTokenBuffer:
"""Token buffer for MLX-based decoding."""
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
self.text = text
self.prefix_token_ids = prefix_token_ids or []
self.tokenizer = tokenizer
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_mlx_array(self) -> mx.array:
tok_ids = self.as_token_ids()
return mx.array([tok_ids], dtype=mx.int32)
def as_mlx_array_beam(self, beam: int) -> mx.array:
t = self.as_mlx_array()
return mx.repeat(t, beam, axis=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return MLXTokenBuffer(*a, **kw)
@staticmethod
def from_text(text, *a, **kw):
return MLXTokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
"""Apply median filter along the last axis."""
if filter_width <= 1:
return x
pad_width = filter_width // 2
shape = x.shape
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
result = []
for i in range(shape[-1]):
window = x_padded[..., i:i + filter_width]
sorted_window = mx.sort(window, axis=-1)
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
result.append(median_val)
return mx.concatenate(result, axis=-1)
class MLXAlignAtt(AlignAttBase):
"""
MLX-native Alignment-based Attention decoder for SimulStreaming.
Runs entirely on MLX, with no PyTorch dependencies for inference.
"""
def __init__(
self,
cfg: AlignAttConfig,
mlx_model: Any,
) -> None:
# Common init (sets self.model, self.cfg, decode_options, etc.)
self._base_init(cfg, mlx_model)
logger.info(f"MLX Model dimensions: {self.model.dims}")
# Per-session state
self.state = MLXDecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
self._init_state_common(cfg)
# CIF: MLX doesn't support CIF checkpoint loading
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
self.state.never_fire = True
self.state.always_fire = False
else:
self.state.always_fire = True
self.state.never_fire = False
else:
logger.warning(
"CIF checkpoint provided but MLX CIF not implemented. "
"Using always_fire=True"
)
self.state.always_fire = True
self.state.never_fire = cfg.never_fire
self._build_alignment_source()
# Suppress tokens
suppress_tokens = [
self.tokenizer.transcribe, self.tokenizer.translate,
self.tokenizer.sot, self.tokenizer.sot_prev,
self.tokenizer.sot_lm, self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
self.init_tokens()
self.init_context()
# Decoder type
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using MLX greedy decoder")
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
logger.info("Using MLX beam decoder")
self.state.inference = MLXInference(
self.model, self.state.initial_token_length,
)
self.state.token_decoder = MLXBeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size,
)
def _build_alignment_source(self):
"""Build alignment source mapping from model's alignment_heads."""
self.state.align_source = {}
self.state.num_align_heads = 0
alignment_heads = self.model.alignment_heads
if alignment_heads is None:
logger.warning("No alignment heads found in model")
return
if hasattr(alignment_heads, 'tolist'):
heads_list = alignment_heads.tolist()
else:
heads_list = np.array(alignment_heads).tolist()
for layer_rank, head_id in heads_list:
layer_rank = int(layer_rank)
head_id = int(head_id)
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
# === Abstract method implementations ===
def init_tokens(self):
logger.debug(f"init tokens, {len(self.state.segments)}")
self.state.initial_tokens = mx.array(
[self.tokenizer.sot_sequence_including_notimestamps],
dtype=mx.int32,
)
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def init_context(self):
kw = {
'tokenizer': self.tokenizer,
'prefix_token_ids': [self.tokenizer.sot_prev],
}
self.state.context = MLXTokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.state.context.text += self.cfg.init_prompt
def insert_audio(self, segment=None):
if segment is not None:
if hasattr(segment, 'numpy'):
segment = segment.numpy()
self.state.segments.append(segment)
removed_len = 0
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len
self.state.segments = self.state.segments[1:]
logger.debug(
f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, "
f"cumulative offset: {self.state.cumulative_time_offset:.2f}s"
)
if len(self.state.tokens) > 1:
token_list = np.array(self.state.tokens[1][0, :]).tolist()
self.state.context.append_token_ids(token_list)
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _current_tokens(self) -> mx.array:
toks = self.state.tokens
if toks[0].shape[0] == 1:
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
if not self.state.context.is_empty():
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
toks = [context_toks] + toks
if len(toks) > 1:
current_tokens = mx.concatenate(toks, axis=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return True # MLX CIF not implemented
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
n_audio = encoder_features.shape[0]
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
logits = logits[:, 0]
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
language_token_indices = mx.array(
list(self.tokenizer.all_language_tokens), dtype=mx.int32,
)
mask = mask.at[language_token_indices].add(False)
logits = mx.where(mask, mx.array(-float('inf')), logits)
language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1)
probs_np = np.array(language_token_probs)
language_probs = [
{
c: float(probs_np[i, j])
for j, c in zip(
self.tokenizer.all_language_tokens,
self.tokenizer.all_language_codes,
)
}
for i in range(n_audio)
]
self._clean_cache()
return language_tokens, language_probs
def _concat_segments(self):
if len(self.state.segments) > 1:
return np.concatenate(self.state.segments, axis=0)
return self.state.segments[0]
def _encode(self, input_segments):
mlx_mel_padded = mlx_log_mel_spectrogram(
audio=input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES,
)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
encoder_feature = self.model.encoder(mlx_mel[None])
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
return encoder_feature, content_mel_len
def _init_sum_logprobs(self):
return mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
if self.state.decoder_type == "greedy":
logits, self.state.kv_cache, cross_qk = self.model.decoder(
tokens, encoder_feature, kv_cache=self.state.kv_cache,
)
return logits, cross_qk
else:
return self.state.inference.logits(tokens, encoder_feature)
def _check_no_speech(self, logits):
if self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
no_speech_probs = np.array(
probs_at_sot[:, self.tokenizer.no_speech],
).tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
return True
return False
def _suppress_blank_tokens(self, logits):
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
logits = logits.at[:, blank_tokens].add(-float('inf'))
return logits
def _apply_token_suppression(self, logits):
if self.state.suppress_tokens:
suppress_indices = mx.array(
list(self.state.suppress_tokens), dtype=mx.int32,
)
logits = logits.at[:, suppress_indices].add(-float('inf'))
return logits
def _update_tokens(self, current_tokens, logits, sum_logprobs):
return self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
def _process_cross_attention(
self, cross_attns: List, content_mel_len: int,
) -> mx.array:
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = self.num_decoder_layers
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
if attn_mat is None:
continue
layer_rank = idx % num_decoder_layers
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if not align_heads_in_layer:
continue
attn_mat = mx.softmax(attn_mat, axis=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
if attn_mat.ndim == 4:
a = attn_mat[0, head_id, :, :]
else:
a = attn_mat[head_id, :, :]
a = a[None, :, :]
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
tmp.append(mx.concatenate(mat, axis=1))
if not tmp:
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
attn_of_alignment_heads = mx.stack(tmp, axis=1)
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
mx.eval(attn_of_alignment_heads)
return attn_of_alignment_heads
def _get_attended_frames(self, attn):
most_attended_frames = mx.argmax(attn[:, -1, :], axis=-1)
frames_np = np.array(most_attended_frames)
return frames_np.tolist(), int(frames_np[0])
def _is_special_token(self, current_tokens):
return int(np.array(current_tokens[0, -2])) >= DEC_PAD
def _rewind_tokens(self):
if len(self.state.tokens) > 0:
return mx.concatenate(self.state.tokens, axis=1)
return self.state.tokens[0]
def _tokens_to_list(self, current_tokens, start_col):
return np.array(current_tokens[0, start_col:]).tolist()
def _make_new_tokens_tensor(self, hypothesis):
new_tokens = mx.array([hypothesis], dtype=mx.int32)
return mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
def _evaluate(self, tensor):
mx.eval(tensor)

View File

@@ -5,24 +5,11 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten
from mlx_whisper import whisper from mlx_whisper import whisper
mlx_model_mapping = { from whisperlivekit.model_mapping import MLX_MODEL_MAPPING
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx", mlx_model_mapping = MLX_MODEL_MAPPING
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
def load_mlx_encoder( def load_mlx_encoder(
path_or_hf_repo: str, path_or_hf_repo: str,
@@ -69,4 +56,40 @@ def load_mlx_encoder(
model.update(encoder_weights) model.update(encoder_weights)
mx.eval(model.parameters()) mx.eval(model.parameters())
return model
def load_mlx_model(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model return model

View File

@@ -1,289 +1,198 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
import os
import logging import logging
import os
from typing import List
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from .whisper import load_model, DecodingOptions, tokenizer from whisperlivekit.backend_support import (faster_backend_available,
from .config import AlignAttConfig mlx_backend_available)
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES TOKENS_PER_SECOND,
from .whisper.timing import median_filter log_mel_spectrogram, pad_or_trim)
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
SuppressTokens)
from whisperlivekit.whisper.timing import median_filter
from .align_att_base import DEC_PAD, AlignAttBase
from .beam import BeamPyTorchInference from .beam import BeamPyTorchInference
from .config import AlignAttConfig
from .decoder_state import DecoderState
from .eow_detection import fire_at_boundary, load_cif from .eow_detection import fire_at_boundary, load_cif
import os
from time import time
from .token_buffer import TokenBuffer from .token_buffer import TokenBuffer
import numpy as np
from ..timed_objects import PUNCTUATION_MARKS
from .generation_progress import *
DEC_PAD = 50257
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if mlx_backend_available():
try: from mlx_whisper.audio import \
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
class PaddedAlignAttWhisper: if faster_backend_available():
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
USE_MLCORE = False
def load_coreml_encoder():
try:
from coremltools.models import MLModel
except ImportError:
logger.warning("coremltools is not installed")
return None
COREML_ENCODER_PATH = os.environ.get(
"MLCORE_ENCODER_PATH",
"whisperlivekit/whisper/whisper_encoder.mlpackage",
)
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
spec = _coreml_encoder.get_spec()
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
_coreml_output_name = spec.description.output[0].name if spec.description.output else None
return _coreml_encoder, _coreml_input_name, _coreml_output_name
class AlignAtt(AlignAttBase):
"""
PyTorch Alignment-based Attention decoder for SimulStreaming.
Hookless — the model can be shared across multiple sessions,
with each session maintaining its own DecoderState.
"""
def __init__( def __init__(
self, self,
cfg: AlignAttConfig, cfg: AlignAttConfig,
loaded_model=None, loaded_model=None,
mlx_encoder=None, mlx_encoder=None,
fw_encoder=None, fw_encoder=None,
) -> None: ) -> None:
self.log_segments = 0
self.model = loaded_model
self.mlx_encoder = mlx_encoder self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder self.fw_encoder = fw_encoder
if fw_encoder: if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) self.fw_feature_extractor = FeatureExtractor(
feature_size=loaded_model.dims.n_mels,
)
self.coreml_encoder_tuple = None
if USE_MLCORE:
self.coreml_encoder_tuple = load_coreml_encoder()
self.use_mlcore = self.coreml_encoder_tuple is not None
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Common init (sets self.model, self.cfg, decode_options, etc.)
self._base_init(cfg, loaded_model)
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
self.speaker = -1
self.decode_options = DecodingOptions( # Per-session state
language = cfg.language, self.state = DecoderState()
without_timestamps = True, self._init_state(cfg)
task=cfg.task
def _init_state(self, cfg: AlignAttConfig):
self._init_state_common(cfg)
# CIF helpers for end-of-word boundary detection
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
cfg, n_audio_state=self.model.dims.n_audio_state, device=self.model.device,
) )
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
# self.create_tokenizer('en')
self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg
self.l_hooks = []
# model to detect end-of-word boundary at the end of the segment # Build alignment source mapping
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg, self.state.align_source = {}
n_audio_state=self.model.dims.n_audio_state, self.state.num_align_heads = 0
device=self.model.device)
# install hooks to access encoder-decoder attention
self.dec_attns = []
def layer_hook(module, net_input, net_output):
# net_output[1]: B*num_head*token_len*audio_len
t = F.softmax(net_output[1], dim=-1)
self.dec_attns.append(t.squeeze(0))
for b in self.model.decoder.blocks:
hook = b.cross_attn.register_forward_hook(layer_hook)
self.l_hooks.append(hook)
self.kv_cache = {}
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
# save as-is, for the first token or cross attention
self.kv_cache[module.cache_id] = net_output
else:
x = self.kv_cache[module.cache_id]
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
return self.kv_cache[module.cache_id]
for i,b in enumerate(self.model.decoder.blocks):
hooks = [
b.attn.key.register_forward_hook(kv_hook),
b.attn.value.register_forward_hook(kv_hook),
b.cross_attn.key.register_forward_hook(kv_hook),
b.cross_attn.value.register_forward_hook(kv_hook),
]
self.l_hooks.extend(hooks)
self.align_source = {}
self.num_align_heads = 0
for layer_rank, head_id in self.model.alignment_heads.indices().T: for layer_rank, head_id in self.model.alignment_heads.indices().T:
layer_rank = layer_rank.item() layer_rank = layer_rank.item()
heads = self.align_source.get(layer_rank, []) heads = self.state.align_source.get(layer_rank, [])
heads.append((self.num_align_heads, head_id.item())) heads.append((self.state.num_align_heads, head_id.item()))
self.align_source[layer_rank] = heads self.state.align_source[layer_rank] = heads
self.num_align_heads += 1 self.state.num_align_heads += 1
# Build suppress tokens function
# tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [ suppress_tokens = [
self.tokenizer.transcribe, self.tokenizer.transcribe, self.tokenizer.translate,
self.tokenizer.translate, self.tokenizer.sot, self.tokenizer.sot_prev,
self.tokenizer.sot, self.tokenizer.sot_lm, self.tokenizer.no_timestamps,
self.tokenizer.sot_prev, ] + list(self.tokenizer.all_language_tokens)
self.tokenizer.sot_lm,
# self.tokenizer.eot
self.tokenizer.no_timestamps, # added by DM
] + list(self.tokenizer.all_language_tokens) # added by DM
if self.tokenizer.no_speech is not None: if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech) suppress_tokens.append(self.tokenizer.no_speech)
suppress_tokens = tuple(sorted(set(suppress_tokens))) suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {suppress_tokens}") logger.debug(f"Suppress tokens: {suppress_tokens}")
sup_tokens = SuppressTokens(suppress_tokens) sup_tokens = SuppressTokens(suppress_tokens)
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None) self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334
# it's going to be regenerated after lang id
self.segments = []
self.init_tokens() self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.first_timestamp = None
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context() self.init_context()
# decoder type: greedy or beam # Decoder type
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy": if cfg.decoder_type == "greedy":
logger.info("Using greedy decoder") logger.info("Using greedy decoder")
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot) self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
self.decoder_type = "greedy"
elif cfg.decoder_type == "beam": elif cfg.decoder_type == "beam":
self.decoder_type = "beam" logger.info("Using beam decoder")
self.inference = BeamPyTorchInference(self.model, self.initial_token_length) self.state.inference = BeamPyTorchInference(
self.inference.kv_cache = self.kv_cache self.model, self.state.initial_token_length,
)
self.state.inference.kv_cache = self.state.kv_cache
self.state.token_decoder = BeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size,
)
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) # === Abstract method implementations ===
# Tokens to carry over to next chunk for incomplete UTF-8 characters
self.pending_incomplete_tokens = []
def remove_hooks(self):
for hook in self.l_hooks:
hook.remove()
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
def init_context(self):
kw = {'tokenizer': self.tokenizer,
'device': self.model.device,
'prefix_token_ids': [self.tokenizer.sot_prev]}
self.context = TokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt
def init_tokens(self): def init_tokens(self):
logger.debug(f"init tokens, {len(self.segments)}") logger.debug(f"init tokens, {len(self.state.segments)}")
# init tokens (mandatory prompt) self.state.initial_tokens = torch.tensor(
self.initial_tokens = torch.tensor( self.tokenizer.sot_sequence_including_notimestamps,
self.tokenizer.sot_sequence_including_notimestamps, dtype=torch.long, device=self.model.device,
dtype=torch.long, ).unsqueeze(0)
device=self.model.device).unsqueeze(0) self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.initial_token_length = self.initial_tokens.shape[1] self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) logger.debug(f"init tokens after, {len(self.state.segments)}")
# self.segments = [] self.state.tokens = [self.state.initial_tokens]
logger.debug(f"init tokens after, {len(self.segments)}")
self.tokens = [self.initial_tokens]
def trim_context(self): def init_context(self):
logger.info("Trimming context") kw = {
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids) 'tokenizer': self.tokenizer,
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}") 'device': self.model.device,
logger.info(f"Context text: {self.context.as_text()}") 'prefix_token_ids': [self.tokenizer.sot_prev],
# logger.debug(f"Context tensor: {self.context.as_tensor()}") }
l = sum(t.shape[1] for t in self.tokens) + c self.state.context = TokenBuffer.empty(**kw)
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") if self.cfg.static_init_prompt is not None:
if self.cfg.static_init_prompt is None: self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
after = 0 if self.cfg.init_prompt is not None:
else: self.state.context.text += self.cfg.init_prompt
after = len(self.cfg.static_init_prompt)
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
logger.info(f"Context after trim: {self.context.text} (len: {l})")
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
if self.cfg.decoder_type == "greedy":
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
else:
logger.debug(f"Logits shape: {tokens.shape}")
logit = self.inference.logits(tokens, audio_features)
return logit
def refresh_segment(self, complete=False):
logger.debug("Refreshing segment:")
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None
self.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2:
self.segments = self.segments[-2:]
else:
logger.debug("removing all segments.")
self.segments = []
self.log_segments += 1
self.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
if self.always_fire: return True
if self.never_fire: return False
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
def insert_audio(self, segment=None):
if segment is not None:
self.state.segments.append(segment)
removed_len = 0
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len
self.state.segments = self.state.segments[1:]
logger.debug(
f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, "
f"cumulative offset: {self.state.cumulative_time_offset:.2f}s"
)
if len(self.state.tokens) > 1:
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _current_tokens(self): def _current_tokens(self):
toks = self.state.tokens
toks = self.tokens
# very first infer: duplicate start of seq to beam_size
if toks[0].shape[0] == 1: if toks[0].shape[0] == 1:
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0) toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
if not self.state.context.is_empty():
if not self.context.is_empty(): context_toks = self.state.context.as_tensor_beam(
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device) self.cfg.beam_size, device=self.model.device,
)
toks = [context_toks] + toks toks = [context_toks] + toks
# make it one tensor
if len(toks) > 1: if len(toks) > 1:
current_tokens = torch.cat(toks, dim=1) current_tokens = torch.cat(toks, dim=1)
else: else:
@@ -292,66 +201,19 @@ class PaddedAlignAttWhisper:
self.debug_print_tokens(current_tokens) self.debug_print_tokens(current_tokens)
return current_tokens return current_tokens
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
def debug_print_tokens(self, tokens): if self.state.always_fire:
for i in range(self.cfg.beam_size): return True
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist())) if self.state.never_fire:
### audio buffer
def segments_len(self):
segments_len = sum(s.shape[0] for s in self.segments) / 16000
return segments_len
def _apply_minseglen(self):
segments_len = self.segments_len()
# wait for long enough audio to start
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False return False
return True return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
def insert_audio(self, segment=None):
if segment is not None:
self.segments.append(segment)
removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len()
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000
segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.cumulative_time_offset += removed_len # Track cumulative time removed
self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
if len(self.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:].tolist())
self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len
def _clean_cache(self):
'''clean the cache that stores the attention matrices and kv_cache.
It must be called every time after generation with the model.'''
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
@torch.no_grad() @torch.no_grad()
def lang_id(self, encoder_features): def lang_id(self, encoder_features):
"""Language detection from encoder features.
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
"""
# forward pass using a single token, startoftranscript
n_audio = encoder_features.shape[0] n_audio = encoder_features.shape[0]
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1] x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device)
logits = self.model.logits(x, encoder_features)[:, 0] logits = self.model.logits(x, encoder_features)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool) mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(self.tokenizer.all_language_tokens)] = False mask[list(self.tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf logits[:, mask] = -np.inf
@@ -360,276 +222,187 @@ class PaddedAlignAttWhisper:
language_probs = [ language_probs = [
{ {
c: language_token_probs[i, j].item() c: language_token_probs[i, j].item()
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes) for j, c in zip(
self.tokenizer.all_language_tokens,
self.tokenizer.all_language_codes,
)
} }
for i in range(n_audio) for i in range(n_audio)
] ]
single = encoder_features.ndim == 2 single = encoder_features.ndim == 2
if single: if single:
language_tokens = language_tokens[0] language_tokens = language_tokens[0]
language_probs = language_probs[0] language_probs = language_probs[0]
self._clean_cache() self._clean_cache()
return language_tokens, language_probs return language_tokens, language_probs
### transcription / translation def _concat_segments(self):
if len(self.state.segments) > 1:
return torch.cat(self.state.segments, dim=0)
return self.state.segments[0]
@torch.no_grad() def _encode(self, input_segments):
def infer(self, is_last=False): if self.use_mlcore:
new_segment = True coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
if len(self.segments) == 0: mel_padded = log_mel_spectrogram(
logger.debug("No segments, nothing to do") input_segments, n_mels=self.model.dims.n_mels,
return [] padding=N_SAMPLES, device="cpu",
if not self._apply_minseglen(): ).unsqueeze(0)
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") mel = pad_or_trim(mel_padded, N_FRAMES)
input_segments = torch.cat(self.segments, dim=0) content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
return [] mel_np = np.ascontiguousarray(mel.numpy())
ml_inputs = {coreml_input_name or "mel": mel_np}
# input_segments is concatenation of audio, it's one array coreml_outputs = coreml_encoder.predict(ml_inputs)
if len(self.segments) > 1: if coreml_output_name and coreml_output_name in coreml_outputs:
input_segments = torch.cat(self.segments, dim=0) encoder_feature_np = coreml_outputs[coreml_output_name]
else: else:
input_segments = self.segments[0] encoder_feature_np = next(iter(coreml_outputs.values()))
encoder_feature = torch.as_tensor(
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call: np.array(encoder_feature_np), device=self.device,
# logger.debug("Resetting tokenizer to auto for new sentence.") )
# self.create_tokenizer(None)
# self.detected_language = None
# self.init_tokens()
# self.reset_tokenizer_to_auto_next_call = False
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time()
if self.mlx_encoder: if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) mlx_mel_padded = mlx_log_mel_spectrogram(
audio=input_segments.detach(),
n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None]) mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.as_tensor(mlx_encoder_feature) encoder_feature = torch.as_tensor(mlx_encoder_feature)
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
elif self.fw_encoder: elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000 audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2 content_mel_len = int(audio_length_seconds * 100) // 2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] mel_padded_2 = self.fw_feature_extractor(
waveform=input_segments.numpy(), padding=N_SAMPLES,
)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1) mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel) encoder_feature_ctranslate = self.fw_encoder.encode(mel)
if self.device == 'cpu': #it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works if self.device == 'cpu':
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate) encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
try: try:
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device) encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case: except TypeError:
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device) # Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
arr = np.array(encoder_feature_ctranslate)
if arr.dtype == np.object_:
arr = np.array(arr.tolist(), dtype=np.float32)
encoder_feature = torch.as_tensor(arr, device=self.device)
else: else:
# mel + padding to 30s mel_padded = log_mel_spectrogram(
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, input_segments, n_mels=self.model.dims.n_mels,
device=self.device).unsqueeze(0) padding=N_SAMPLES, device=self.device,
# trim to 3000 ).unsqueeze(0)
mel = pad_or_trim(mel_padded, N_FRAMES) mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel) encoder_feature = self.model.encoder(mel)
end_encode = time() return encoder_feature, content_mel_len
# print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
seconds_since_start = self.segments_len() - self.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context() def _init_sum_logprobs(self):
current_tokens = self._current_tokens() return torch.zeros(self.cfg.beam_size, device=self.device)
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device) if self.state.decoder_type == "greedy":
completed = False return self.model.decoder(
# punctuation_stop = False tokens, encoder_feature,
kv_cache=self.state.kv_cache,
attn_of_alignment_heads = None return_cross_attn=True,
most_attended_frame = None )
else:
token_len_before_decoding = current_tokens.shape[1] logger.debug(f"Logits shape: {tokens.shape}")
return self.state.inference.logits(
l_absolute_timestamps = [] tokens, encoder_feature, return_cross_attn=True,
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
if new_segment:
tokens_for_logits = current_tokens
else:
# only need to use the last token except in the first forward pass
tokens_for_logits = current_tokens[:,-1:]
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # logits for the last token
# supress blank tokens only at the beginning of the segment
if new_segment:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
new_segment = False
self.suppress_tokens(logits)
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens)
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
for i, attn_mat in enumerate(self.dec_attns):
layer_rank = int(i % len(self.model.decoder.blocks))
align_heads_in_layer = self.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0)
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
t = torch.cat(mat, dim=1)
tmp.append(t)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
# for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
# Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item()
l_absolute_timestamps.append(absolute_timestamps[0])
logger.debug("current tokens" + str(current_tokens.shape))
if completed:
# # stripping the last token, the eot
current_tokens = current_tokens[:, :-1]
break
# for some rare cases where the attention fails
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
# TODO: check this
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
logger.debug("ommit rewinding from special tokens")
self.last_attend_frame = most_attended_frame
else:
logger.debug(
f"[rewind detected] current attention pos: {most_attended_frame}, "
f"last attention pos: {self.last_attend_frame}; omit this segment")
self.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
break
else:
self.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
# stripping the last token, the one that is attended too close to the end
current_tokens = current_tokens[:, :-1]
break
# debug print
for i in range(self.cfg.beam_size):
logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
most_attended_frames[i],
current_tokens[i, -1].item(),
self.tokenizer.decode([current_tokens[i, -1].item()])
))
tokens_to_split = current_tokens[0, token_len_before_decoding:]
# Prepend pending tokens from previous chunk if any
if self.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
if fire_detected or is_last: #or punctuation_stop:
new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
# going to truncate the tokens after the last space
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.device,
)
self.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
self.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
# Skip words containing incomplete UTF-8 from client output
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=current_timestamp,
end=current_timestamp + 0.1,
text= word,
probability=0.95,
speaker=self.speaker,
detected_language=self.detected_language
).with_offset(
self.global_time_offset
) )
timestamped_words.append(timestamp_entry)
# Hold incomplete tokens for next chunk def _check_no_speech(self, logits):
self.pending_incomplete_tokens = [] if self.tokenizer.no_speech is not None:
if split_words and replacement_char in split_words[-1]: probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
self.pending_incomplete_tokens = split_tokens[-1] no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}") if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
return True
return False
return timestamped_words def _suppress_blank_tokens(self, logits):
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
return logits
def _apply_token_suppression(self, logits):
self.state.suppress_tokens_fn(logits)
return logits
def _update_tokens(self, current_tokens, logits, sum_logprobs):
return self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
def _process_cross_attention(
self, cross_attns: List, content_mel_len: int,
) -> torch.Tensor:
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = len(self.model.decoder.blocks)
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
layer_rank = idx % num_decoder_layers
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if not align_heads_in_layer:
continue
attn_mat = F.softmax(attn_mat, dim=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
if attn_mat.dim() == 4:
a = attn_mat[0, head_id, :, :]
else:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0)
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
tmp.append(torch.cat(mat, dim=1))
if not tmp:
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
std, mean = torch.std_mean(
attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False,
)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
return attn_of_alignment_heads
def _get_attended_frames(self, attn):
most_attended_frames = torch.argmax(attn[:, -1, :], dim=-1)
return most_attended_frames.tolist(), most_attended_frames[0].item()
def _is_special_token(self, current_tokens):
return current_tokens[0, -2].item() >= DEC_PAD
def _rewind_tokens(self):
if len(self.state.tokens) > 0:
return torch.cat(self.state.tokens, dim=1)
return self.state.tokens[0]
def _tokens_to_list(self, current_tokens, start_col):
return current_tokens[0, start_col:].flatten().tolist()
def _make_new_tokens_tensor(self, hypothesis):
return (
torch.tensor([hypothesis], dtype=torch.long)
.repeat_interleave(self.cfg.beam_size, dim=0)
.to(device=self.device)
)
def _evaluate(self, tensor):
pass # No-op for PyTorch
@torch.no_grad()
def infer(self, is_last=False):
return super().infer(is_last)

View File

@@ -1,5 +1,8 @@
import torch
import sys import sys
import torch
class TokenBuffer: class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]): def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):

View File

@@ -1,171 +0,0 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only=False,
custom_alignment_heads=None
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)

View File

@@ -0,0 +1,139 @@
"""
Thread Safety Configuration for WhisperLiveKit
This module provides thread safety configuration and utilities.
Environment Variables:
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
Set to "0" to disable for single-connection deployments
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
Usage:
# Enable model locking (default)
export WHISPERLIVEKIT_MODEL_LOCK=1
# Disable for single-connection deployment
export WHISPERLIVEKIT_MODEL_LOCK=0
# Custom timeout
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading
logger = logging.getLogger(__name__)
# Configuration
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
# Global model lock
_model_lock = threading.Lock()
# Log configuration on import
if USE_MODEL_LOCK:
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
def get_model_lock():
"""Get the global model lock instance"""
return _model_lock
def acquire_model_lock(timeout=None):
"""
Acquire model lock with timeout.
Args:
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
Returns:
bool: True if lock acquired, False on timeout
"""
if not USE_MODEL_LOCK:
return True
timeout = timeout or LOCK_TIMEOUT
acquired = _model_lock.acquire(timeout=timeout)
if not acquired:
logger.error(f"Failed to acquire model lock within {timeout}s")
return acquired
def release_model_lock():
"""Release model lock"""
if not USE_MODEL_LOCK:
return
try:
_model_lock.release()
except RuntimeError:
# Lock not held - this is fine
pass
class ModelLockContext:
"""Context manager for model lock"""
def __init__(self, timeout=None):
self.timeout = timeout
self.acquired = False
def __enter__(self):
self.acquired = acquire_model_lock(self.timeout)
return self.acquired
def __exit__(self, exc_type, exc_val, exc_tb):
if self.acquired:
release_model_lock()
return False
# Concurrency recommendations
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
RECOMMENDED_WORKERS = 4
def print_deployment_recommendations():
"""Print recommended deployment configuration"""
print("\n" + "="*60)
print("WhisperLiveKit Deployment Recommendations")
print("="*60)
if USE_MODEL_LOCK:
print("⚠️ Model locking is ENABLED")
print(" This serializes inference across connections.")
print()
print("Recommended deployment:")
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
print(" -k uvicorn.workers.UvicornWorker \\")
print(" --worker-connections 1 \\")
print(" whisperlivekit.basic_server:app")
print()
print("Expected capacity:")
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
else:
print("✅ Model locking is DISABLED")
print(" ⚠️ ONLY safe for single-connection deployments")
print()
print("Recommended deployment:")
print(" uvicorn whisperlivekit.basic_server:app \\")
print(" --host 0.0.0.0 --port 8000 \\")
print(" --workers 1")
print()
print("Expected capacity:")
print(" - 1 concurrent user only")
print("="*60 + "\n")
if __name__ == "__main__":
print_deployment_recommendations()

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Any, List
from datetime import timedelta from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''} PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@@ -8,22 +8,19 @@ def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS.""" """Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds))) return str(timedelta(seconds=int(seconds)))
@dataclass @dataclass
class TimedText: class Timed:
start: Optional[float] = 0 start: Optional[float] = 0
end: Optional[float] = 0 end: Optional[float] = 0
@dataclass
class TimedText(Timed):
text: Optional[str] = '' text: Optional[str] = ''
speaker: Optional[int] = -1 speaker: Optional[int] = -1
probability: Optional[float] = None
is_dummy: Optional[bool] = False
detected_language: Optional[str] = None detected_language: Optional[str] = None
def is_punctuation(self): def has_punctuation(self) -> bool:
return self.text.strip() in PUNCTUATION_MARKS return any(char in PUNCTUATION_MARKS for char in self.text.strip())
def overlaps_with(self, other: 'TimedText') -> bool:
return not (self.end <= other.start or other.end <= self.start)
def is_within(self, other: 'TimedText') -> bool: def is_within(self, other: 'TimedText') -> bool:
return other.contains_timespan(self) return other.contains_timespan(self)
@@ -31,27 +28,26 @@ class TimedText:
def duration(self) -> float: def duration(self) -> float:
return self.end - self.start return self.end - self.start
def contains_time(self, time: float) -> bool:
return self.start <= time <= self.end
def contains_timespan(self, other: 'TimedText') -> bool: def contains_timespan(self, other: 'TimedText') -> bool:
return self.start <= other.start and self.end >= other.end return self.start <= other.start and self.end >= other.end
def __bool__(self): def __bool__(self) -> bool:
return bool(self.text) return bool(self.text)
def __str__(self) -> str:
return str(self.text)
@dataclass() @dataclass()
class ASRToken(TimedText): class ASRToken(TimedText):
probability: Optional[float] = None
corrected_speaker: Optional[int] = -1
validated_speaker: bool = False
validated_text: bool = False
validated_language: bool = False
def with_offset(self, offset: float) -> "ASRToken": def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added.""" """Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language) return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
def is_silence(self) -> bool:
return False
@dataclass @dataclass
class Sentence(TimedText): class Sentence(TimedText):
@@ -70,68 +66,93 @@ class Transcript(TimedText):
sep: Optional[str] = None, sep: Optional[str] = None,
offset: float = 0 offset: float = 0
) -> "Transcript": ) -> "Transcript":
"""Collapse multiple ASR tokens into a single transcript span."""
sep = sep if sep is not None else ' ' sep = sep if sep is not None else ' '
text = sep.join(token.text for token in tokens) text = sep.join(token.text for token in tokens)
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
if tokens: if tokens:
start = offset + tokens[0].start start = offset + tokens[0].start
end = offset + tokens[-1].end end = offset + tokens[-1].end
else: else:
start = None start = None
end = None end = None
return cls(start, end, text, probability=probability) return cls(start, end, text)
@dataclass @dataclass
class SpeakerSegment(TimedText): class SpeakerSegment(Timed):
"""Represents a segment of audio attributed to a specific speaker. """Represents a segment of audio attributed to a specific speaker.
No text nor probability is associated with this segment. No text nor probability is associated with this segment.
""" """
speaker: Optional[int] = -1
pass pass
@dataclass @dataclass
class Translation(TimedText): class Translation(TimedText):
pass pass
def approximate_cut_at(self, cut_time):
"""
Each word in text is considered to be of duration (end-start)/len(words in text)
"""
if not self.text or not self.contains_time(cut_time):
return self, None
words = self.text.split()
num_words = len(words)
if num_words == 0:
return self, None
duration_per_word = self.duration() / num_words
cut_word_index = int((cut_time - self.start) / duration_per_word)
if cut_word_index >= num_words:
cut_word_index = num_words -1
text0 = " ".join(words[:cut_word_index])
text1 = " ".join(words[cut_word_index:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
@dataclass @dataclass
class Silence(): class Silence():
duration: float start: Optional[float] = None
end: Optional[float] = None
duration: Optional[float] = None
is_starting: bool = False
has_ended: bool = False
def compute_duration(self) -> Optional[float]:
if self.start is None or self.end is None:
return None
self.duration = self.end - self.start
return self.duration
def is_silence(self) -> bool:
return True
@dataclass @dataclass
class Line(TimedText): class Segment(TimedText):
translation: str = '' """Generic contiguous span built from tokens or silence markers."""
start: Optional[float]
def to_dict(self): end: Optional[float]
_dict = { text: Optional[str]
speaker: Optional[str]
tokens: Optional[ASRToken] = None
translation: Optional[Translation] = None
@classmethod
def from_tokens(
cls,
tokens: List[Union[ASRToken, Silence]],
is_silence: bool = False
) -> Optional["Segment"]:
"""Return a normalized segment representing the provided tokens."""
if not tokens:
return None
start_token = tokens[0]
end_token = tokens[-1]
if is_silence:
return cls(
start=start_token.start,
end=end_token.end,
text=None,
speaker=-2
)
else:
return cls(
start=start_token.start,
end=end_token.end,
text=''.join(token.text for token in tokens),
speaker=-1,
detected_language=start_token.detected_language
)
def is_silence(self) -> bool:
"""True when this segment represents a silence gap."""
return self.speaker == -2
def to_dict(self) -> Dict[str, Any]:
"""Serialize the segment for frontend consumption."""
_dict: Dict[str, Any] = {
'speaker': int(self.speaker) if self.speaker != -1 else 1, 'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text, 'text': self.text,
'start': format_time(self.start), 'start': format_time(self.start),
@@ -142,24 +163,38 @@ class Line(TimedText):
if self.detected_language: if self.detected_language:
_dict['detected_language'] = self.detected_language _dict['detected_language'] = self.detected_language
return _dict return _dict
@dataclass
class PuncSegment(Segment):
pass
class SilentSegment(Segment):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.speaker = -2
self.text = ''
@dataclass @dataclass
class FrontData(): class FrontData():
status: str = '' status: str = ''
error: str = '' error: str = ''
lines: list[Line] = field(default_factory=list) lines: list[Segment] = field(default_factory=list)
buffer_transcription: str = '' buffer_transcription: str = ''
buffer_diarization: str = '' buffer_diarization: str = ''
buffer_translation: str = ''
remaining_time_transcription: float = 0. remaining_time_transcription: float = 0.
remaining_time_diarization: float = 0. remaining_time_diarization: float = 0.
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
_dict = { """Serialize the front-end data payload."""
_dict: Dict[str, Any] = {
'status': self.status, 'status': self.status,
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)], 'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
'buffer_transcription': self.buffer_transcription, 'buffer_transcription': self.buffer_transcription,
'buffer_diarization': self.buffer_diarization, 'buffer_diarization': self.buffer_diarization,
'buffer_translation': self.buffer_translation,
'remaining_time_transcription': self.remaining_time_transcription, 'remaining_time_transcription': self.remaining_time_transcription,
'remaining_time_diarization': self.remaining_time_diarization, 'remaining_time_diarization': self.remaining_time_diarization,
} }
@@ -174,13 +209,22 @@ class ChangeSpeaker:
@dataclass @dataclass
class State(): class State():
tokens: list = field(default_factory=list) """Unified state class for audio processing.
last_validated_token: int = 0
translation_validated_segments: list = field(default_factory=list) Contains both persistent state (tokens, buffers) and temporary update buffers
translation_buffer: list = field(default_factory=list) (new_* fields) that are consumed by TokensAlignment.
buffer_transcription: str = field(default_factory=Transcript) """
# Persistent state
tokens: List[ASRToken] = field(default_factory=list)
buffer_transcription: Transcript = field(default_factory=Transcript)
end_buffer: float = 0.0 end_buffer: float = 0.0
end_attributed_speaker: float = 0.0 end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0 remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0 remaining_time_diarization: float = 0.0
beg_loop: Optional[int] = None
# Temporary update buffers (consumed by TokensAlignment.update())
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
new_translation: List[Any] = field(default_factory=list)
new_diarization: List[Any] = field(default_factory=list)
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
new_translation_buffer= TimedText()

View File

@@ -0,0 +1,220 @@
from time import time
from typing import Any, List, Optional, Tuple, Union
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
SilentSegment, SpeakerSegment,
TimedText)
class TokensAlignment:
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
self.state = state
self.diarization = args.diarization
self._tokens_index: int = 0
self._diarization_index: int = 0
self._translation_index: int = 0
self.all_tokens: List[ASRToken] = []
self.all_diarization_segments: List[SpeakerSegment] = []
self.all_translation_segments: List[Any] = []
self.new_tokens: List[ASRToken] = []
self.new_diarization: List[SpeakerSegment] = []
self.new_translation: List[Any] = []
self.new_translation_buffer: Union[TimedText, str] = TimedText()
self.new_tokens_buffer: List[Any] = []
self.sep: str = sep if sep is not None else ' '
self.beg_loop: Optional[float] = None
self.validated_segments: List[Segment] = []
self.current_line_tokens: List[ASRToken] = []
self.diarization_buffer: List[ASRToken] = []
self.last_punctuation = None
self.last_uncompleted_punc_segment: PuncSegment = None
self.unvalidated_tokens: PuncSegment = []
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
self.new_translation, self.state.new_translation = self.state.new_translation, []
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
self.all_tokens.extend(self.new_tokens)
self.all_diarization_segments.extend(self.new_diarization)
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def add_translation(self, segment: Segment) -> None:
"""Append translated text segments that overlap with a segment."""
if segment.translation is None:
segment.translation = ''
for ts in self.all_translation_segments:
if ts.is_within(segment):
if ts.text:
segment.translation += ts.text + self.sep
elif segment.translation:
break
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
"""Group tokens into segments split by punctuation and explicit silence."""
segments = []
segment_start_idx = 0
for i, token in enumerate(self.all_tokens):
if token.is_silence():
previous_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i],
)
if previous_segment:
segments.append(previous_segment)
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i+1],
)
segments.append(segment)
segment_start_idx = i+1
final_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx:],
)
if final_segment:
segments.append(final_segment)
return segments
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
new_punc_segments = []
segment_start_idx = 0
self.unvalidated_tokens += self.new_tokens
for i, token in enumerate(self.unvalidated_tokens):
if token.is_silence():
previous_segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i],
)
if previous_segment:
new_punc_segments.append(previous_segment)
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
new_punc_segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
)
new_punc_segments.append(segment)
segment_start_idx = i+1
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
return new_punc_segments
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
"""Merge consecutive diarization slices that share the same speaker."""
if not self.all_diarization_segments:
return []
merged = [self.all_diarization_segments[0]]
for segment in self.all_diarization_segments[1:]:
if segment.speaker == merged[-1].speaker:
merged[-1].end = segment.end
else:
merged.append(segment)
return merged
@staticmethod
def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
"""Return the overlap duration between two timed segments."""
start = max(seg1.start, seg2.start)
end = min(seg1.end, seg2.end)
return max(0, end - start)
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
"""Build segments when diarization is enabled and track overflow buffer."""
diarization_buffer = ''
punctuation_segments = self.compute_punctuations_segments()
diarization_segments = self.concatenate_diar_segments()
for punctuation_segment in punctuation_segments:
if not punctuation_segment.is_silence():
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
diarization_buffer += punctuation_segment.text
else:
max_overlap = 0.0
max_overlap_speaker = 1
for diarization_segment in diarization_segments:
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
if intersec > max_overlap:
max_overlap = intersec
max_overlap_speaker = diarization_segment.speaker + 1
punctuation_segment.speaker = max_overlap_speaker
segments = []
if punctuation_segments:
segments = [punctuation_segments[0]]
for segment in punctuation_segments[1:]:
if segment.speaker == segments[-1].speaker:
if segments[-1].text:
segments[-1].text += segment.text
segments[-1].end = segment.end
else:
segments.append(segment)
return segments, diarization_buffer
def get_lines(
self,
diarization: bool = False,
translation: bool = False,
current_silence: Optional[Silence] = None
) -> Tuple[List[Segment], str, Union[str, TimedText]]:
"""Return the formatted segments plus buffers, optionally with diarization/translation."""
if diarization:
segments, diarization_buffer = self.get_lines_diarization()
else:
diarization_buffer = ''
for token in self.new_tokens:
if isinstance(token, Silence):
if self.current_line_tokens:
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence
else:
self.validated_segments.append(SilentSegment(
start=token.start,
end=end_silence
))
else:
self.current_line_tokens.append(token)
segments = list(self.validated_segments)
if self.current_line_tokens:
segments.append(Segment.from_tokens(self.current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if segments and segments[-1].is_silence():
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
else:
segments.append(SilentSegment(
start=current_silence.start,
end=end_silence
))
if translation:
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
return segments, diarization_buffer, self.new_translation_buffer.text

View File

@@ -1,60 +0,0 @@
from typing import Sequence, Callable, Any, Optional, Dict
def _detect_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
max_tail: int = 300, # search window from the end for speed
prefer: str = "longest", # "longest" coverage or "smallest" block
) -> Optional[Dict]:
vals = [key(x) for x in seq][-max_tail:]
n = len(vals)
best = None
# try every possible block length
for b in range(min_block, n // 2 + 1):
block = vals[-b:]
# count how many times this block repeats contiguously at the very end
count, i = 0, n
while i - b >= 0 and vals[i - b:i] == block:
count += 1
i -= b
if count >= 2:
cand = {
"block_size": b,
"count": count,
"start_index": len(seq) - count * b, # in original seq
"end_index": len(seq),
}
if (best is None or
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
(prefer == "smallest" and b < best["block_size"])):
best = cand
return best
def trim_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x,
min_block: int = 1,
max_tail: int = 300,
prefer: str = "longest",
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
):
"""
Returns a new sequence with repeated tail trimmed.
keep=1 -> keep a single copy of the repeated block.
keep=0 -> remove all copies of the repeated block.
"""
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
if not rep:
return seq, False # nothing to trim
b, c = rep["block_size"], rep["count"]
if keep < 0:
keep = 0
if keep >= c:
return seq, False # nothing to trim (already <= keep copies)
# new length = total - (copies_to_remove * block_size)
new_len = len(seq) - (c - keep) * b
return seq[:new_len], True

View File

@@ -0,0 +1,395 @@
"""
Voxtral Mini Realtime streaming backend using HuggingFace Transformers.
Uses VoxtralRealtimeForConditionalGeneration with a background generate thread
and queue-based audio feeding for real-time streaming transcription.
Supports CUDA, CPU, and MPS devices.
"""
import logging
import queue
import sys
import threading
import time
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
class VoxtralHFStreamingASR:
"""Voxtral model holder using HuggingFace Transformers."""
sep = " "
def __init__(self, logfile=sys.stderr, **kwargs):
import torch
from transformers import (
AutoProcessor,
VoxtralRealtimeForConditionalGeneration,
)
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
DEFAULT_MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = DEFAULT_MODEL
t = time.time()
logger.info(f"Loading Voxtral model '{model_path}' via HF Transformers...")
self.processor = AutoProcessor.from_pretrained(model_path)
self.model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
)
logger.info(f"Voxtral HF model loaded in {time.time() - t:.2f}s on {self.model.device}")
self.backend_choice = "voxtral"
self.tokenizer = None # sentence tokenizer — not needed for streaming
def transcribe(self, audio):
pass
class VoxtralHFStreamingOnlineProcessor:
"""
Online processor for Voxtral streaming ASR via HuggingFace Transformers.
Uses a background thread running model.generate() with a queue-based
input_features_generator and TextIteratorStreamer for real-time output.
Each decoded token corresponds to ~80ms of audio.
"""
SAMPLING_RATE = 16000
def __init__(self, asr: VoxtralHFStreamingASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.audio_buffer = np.array([], dtype=np.float32)
processor = asr.processor
self._first_chunk_samples = processor.num_samples_first_audio_chunk
self._chunk_samples = processor.num_samples_per_audio_chunk
self._chunk_step = processor.raw_audio_length_per_tok
n_right_pad = processor.num_right_pad_tokens
if callable(n_right_pad):
n_right_pad = n_right_pad()
self._right_pad_samples = int(n_right_pad * processor.raw_audio_length_per_tok)
self._seconds_per_token = processor.raw_audio_length_per_tok / self.SAMPLING_RATE
self._reset_state()
logger.info(
f"[voxtral-hf] Initialized. first_chunk={self._first_chunk_samples} samples, "
f"chunk={self._chunk_samples}, step={self._chunk_step}, "
f"right_pad={self._right_pad_samples}"
)
def _reset_state(self):
self._pending_audio = np.zeros(0, dtype=np.float32)
self._audio_queue: queue.Queue = queue.Queue()
self._streamer_texts: List[str] = []
self._generate_thread: Optional[threading.Thread] = None
self._generate_started = False
self._generate_finished = False
self._generate_error: Optional[Exception] = None
# Text accumulation and word extraction
self._accumulated_text = ""
self._n_text_tokens_received = 0
self._n_committed_words = 0
self._global_time_offset = 0.0
# Lock for text state accessed from both generate thread and main thread
self._text_lock = threading.Lock()
# ── Interface methods ──
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending_audio = np.append(self._pending_audio, audio)
self.audio_buffer = self._pending_audio
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._process_iter_inner(is_last)
except Exception as e:
logger.warning(f"[voxtral-hf] process_iter exception: {e}", exc_info=True)
return [], self.end
def get_buffer(self) -> Transcript:
"""Return all uncommitted text as buffer."""
with self._text_lock:
text = self._accumulated_text
if not text:
return Transcript(start=None, end=None, text="")
words = text.split()
uncommitted = words[self._n_committed_words:]
if uncommitted:
return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all uncommitted words when silence starts."""
self._drain_streamer()
words = self._flush_all_pending_words()
logger.info(f"[voxtral-hf] start_silence: flushed {len(words)} words")
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
"""Flush remaining audio with right-padding and stop the generate thread."""
# Add right-padding so the model can finish decoding
if self._right_pad_samples > 0:
right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
self._pending_audio = np.append(self._pending_audio, right_pad)
# Feed remaining audio
if self._generate_started and not self._generate_finished:
self._feed_pending_audio()
# Signal end of audio
self._audio_queue.put(None)
# Wait for generate to finish
if self._generate_thread is not None:
self._generate_thread.join(timeout=30.0)
elif not self._generate_started and len(self._pending_audio) >= self._first_chunk_samples:
# Never started but have enough audio — start and immediately finish
self._start_generate_thread()
self._feed_pending_audio()
self._audio_queue.put(None)
if self._generate_thread is not None:
self._generate_thread.join(timeout=30.0)
self._drain_streamer()
words = self._flush_all_pending_words()
logger.info(f"[voxtral-hf] finish: flushed {len(words)} words")
return words, self.end
# ── Generate thread management ──
def _start_generate_thread(self):
"""Start model.generate() in a background thread with streaming."""
import torch
from transformers import TextIteratorStreamer
processor = self.asr.processor
model = self.asr.model
# Extract first chunk
first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
self._pending_audio = self._pending_audio[self._first_chunk_samples:]
first_inputs = processor(
first_chunk_audio,
is_streaming=True,
is_first_audio_chunk=True,
return_tensors="pt",
)
first_inputs = first_inputs.to(model.device, dtype=model.dtype)
streamer = TextIteratorStreamer(
processor.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
self._streamer = streamer
audio_queue = self._audio_queue
def input_features_gen():
yield first_inputs.input_features
while True:
chunk_audio = audio_queue.get()
if chunk_audio is None:
break
inputs = processor(
chunk_audio,
is_streaming=True,
is_first_audio_chunk=False,
return_tensors="pt",
)
inputs = inputs.to(model.device, dtype=model.dtype)
yield inputs.input_features
def run_generate():
try:
with torch.no_grad():
# Pass generator as input_features — the model detects GeneratorType
# and internally converts it to input_features_generator
generate_kwargs = {
k: v for k, v in first_inputs.items()
if k != "input_features"
}
model.generate(
input_features=input_features_gen(),
streamer=streamer,
**generate_kwargs,
)
except Exception as e:
logger.error(f"[voxtral-hf] generate error: {e}", exc_info=True)
self._generate_error = e
finally:
self._generate_finished = True
self._generate_thread = threading.Thread(target=run_generate, daemon=True)
self._generate_thread.start()
self._generate_started = True
logger.info("[voxtral-hf] generate thread started")
def _feed_pending_audio(self):
"""Convert pending audio into properly-sized chunks for the generator."""
chunk_size = self._chunk_samples
step_size = self._chunk_step
while len(self._pending_audio) >= chunk_size:
chunk = self._pending_audio[:chunk_size]
self._audio_queue.put(chunk)
self._pending_audio = self._pending_audio[step_size:]
self.audio_buffer = self._pending_audio
def _drain_streamer(self):
"""Non-blocking drain of all available text from the streamer."""
if not self._generate_started:
return
text_queue = self._streamer.text_queue
while True:
try:
text_fragment = text_queue.get_nowait()
except queue.Empty:
break
# TextIteratorStreamer uses None as end-of-stream sentinel
if text_fragment is None:
self._generate_finished = True
break
if text_fragment:
with self._text_lock:
self._accumulated_text += text_fragment
self._n_text_tokens_received += 1
# ── Word extraction ──
def _pos_to_time(self, token_position: int) -> float:
"""Convert token position to seconds."""
return token_position * self._seconds_per_token + self._global_time_offset
def _extract_new_words(self) -> List[ASRToken]:
"""Extract complete words (all but the last, which may still be growing)."""
with self._text_lock:
text = self._accumulated_text
if not text:
return []
words = text.split()
new_words: List[ASRToken] = []
n_tokens = self._n_text_tokens_received
n_words_total = len(words)
while len(words) > self._n_committed_words + 1:
word = words[self._n_committed_words]
word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_tokens) if n_words_total > 0 else 0
tok_end = int((word_idx + 1) / n_words_total * n_tokens) if n_words_total > 0 else 0
start_time = self._pos_to_time(tok_start)
end_time = self._pos_to_time(tok_end)
text_out = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
self._n_committed_words += 1
return new_words
def _flush_all_pending_words(self) -> List[ASRToken]:
"""Flush ALL words including the last partial one."""
with self._text_lock:
text = self._accumulated_text
if not text:
return []
words = text.split()
new_words: List[ASRToken] = []
n_tokens = max(self._n_text_tokens_received, 1)
n_words_total = max(len(words), 1)
while self._n_committed_words < len(words):
word = words[self._n_committed_words]
word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_tokens)
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
start_time = self._pos_to_time(tok_start)
end_time = self._pos_to_time(tok_end)
text_out = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
self._n_committed_words += 1
return new_words
# ── Core processing ──
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
# Start generate thread when enough audio is buffered
if not self._generate_started:
if len(self._pending_audio) >= self._first_chunk_samples:
self._start_generate_thread()
self._feed_pending_audio()
else:
return [], self.end
# Feed any new pending audio
if self._generate_started and not self._generate_finished:
self._feed_pending_audio()
# If generate finished unexpectedly (EOS) but new audio arrived, restart
if self._generate_finished and len(self._pending_audio) >= self._first_chunk_samples:
self._drain_streamer()
flush_words = self._flush_all_pending_words()
# Reset for new utterance
old_offset = self._global_time_offset
self._reset_state()
self._global_time_offset = old_offset
self._start_generate_thread()
self._feed_pending_audio()
return flush_words, self.end
# Drain available text from streamer
self._drain_streamer()
# Extract complete words
new_words = self._extract_new_words()
if new_words:
logger.info(f"[voxtral-hf] returning {len(new_words)} words: {[w.text for w in new_words]}")
self.buffer = []
return new_words, self.end

View File

@@ -0,0 +1,6 @@
"""Pure-MLX Voxtral Realtime backend for WhisperLiveKit."""
from .loader import load_voxtral_model
from .model import VoxtralMLXModel
__all__ = ["load_voxtral_model", "VoxtralMLXModel"]

View File

@@ -0,0 +1,282 @@
"""
Model weight loading for the MLX Voxtral Realtime backend.
Supports two on-disk formats:
1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load,
with optional quantisation metadata.
2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``):
requires weight renaming and conv-weight transposition.
The public entry point is :func:`load_voxtral_model` which returns the
model, tokenizer, and raw config dict.
"""
import json
import logging
import re
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from .model import VoxtralMLXModel
logger = logging.getLogger(__name__)
DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
# ---------------------------------------------------------------------------
# Downloading
# ---------------------------------------------------------------------------
_ALLOWED_PATTERNS = [
"consolidated.safetensors",
"model*.safetensors",
"model.safetensors.index.json",
"params.json",
"config.json",
"tekken.json",
]
def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path:
"""Download model files from HuggingFace Hub and return the local path."""
return Path(snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS))
# ---------------------------------------------------------------------------
# Weight name remapping (Mistral → our naming)
# ---------------------------------------------------------------------------
_NAME_RULES: list[tuple[str, str]] = [
# Encoder convolutions
(r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"),
(r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"),
# Encoder transformer blocks
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)",
r"encoder.blocks.\1.self_attn.q_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)",
r"encoder.blocks.\1.self_attn.k_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)",
r"encoder.blocks.\1.self_attn.v_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)",
r"encoder.blocks.\1.self_attn.out_proj.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)",
r"encoder.blocks.\1.pre_attn_norm.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)",
r"encoder.blocks.\1.ffn.gate.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)",
r"encoder.blocks.\1.ffn.down.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)",
r"encoder.blocks.\1.ffn.up.\2"),
(r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)",
r"encoder.blocks.\1.pre_ffn_norm.\2"),
(r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"),
# Adapter
(r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"),
(r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"),
# Decoder embedding
(r"tok_embeddings\.weight", r"decoder.token_embedding.weight"),
# Decoder blocks
(r"layers\.(\d+)\.attention\.wq\.weight",
r"decoder.blocks.\1.self_attn.q_proj.weight"),
(r"layers\.(\d+)\.attention\.wk\.weight",
r"decoder.blocks.\1.self_attn.k_proj.weight"),
(r"layers\.(\d+)\.attention\.wv\.weight",
r"decoder.blocks.\1.self_attn.v_proj.weight"),
(r"layers\.(\d+)\.attention\.wo\.weight",
r"decoder.blocks.\1.self_attn.out_proj.weight"),
(r"layers\.(\d+)\.attention_norm\.weight",
r"decoder.blocks.\1.pre_attn_norm.weight"),
(r"layers\.(\d+)\.feed_forward\.w1\.weight",
r"decoder.blocks.\1.ffn.gate.weight"),
(r"layers\.(\d+)\.feed_forward\.w2\.weight",
r"decoder.blocks.\1.ffn.down.weight"),
(r"layers\.(\d+)\.feed_forward\.w3\.weight",
r"decoder.blocks.\1.ffn.up.weight"),
(r"layers\.(\d+)\.ffn_norm\.weight",
r"decoder.blocks.\1.pre_ffn_norm.weight"),
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight",
r"decoder.blocks.\1.adaptive_scale.proj_in.weight"),
(r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight",
r"decoder.blocks.\1.adaptive_scale.proj_out.weight"),
# Decoder final norm
(r"norm\.weight", r"decoder.final_norm.weight"),
]
_PREFIX_STRIP = re.compile(
r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\."
)
def _translate_weight_name(name: str) -> str | None:
name = _PREFIX_STRIP.sub("", name)
for pattern, replacement in _NAME_RULES:
result, n = re.subn(f"^{pattern}$", replacement, name)
if n:
return result
return None
def _is_conv_weight(name: str) -> bool:
return ("conv1.weight" in name or "conv2.weight" in name) and "bias" not in name
# ---------------------------------------------------------------------------
# Converted-format weight remapping (voxmlx names → our names)
# ---------------------------------------------------------------------------
_CONVERTED_RULES: list[tuple[str, str]] = [
# Adapter
(r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"),
(r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"),
# Encoder transformer blocks
(r"encoder\.layers\.(\d+)\.attention\.(.*)", r"encoder.blocks.\1.self_attn.\2"),
(r"encoder\.layers\.(\d+)\.attn_norm\.(.*)", r"encoder.blocks.\1.pre_attn_norm.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"encoder.blocks.\1.ffn.gate.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"encoder.blocks.\1.ffn.down.\2"),
(r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"encoder.blocks.\1.ffn.up.\2"),
(r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)", r"encoder.blocks.\1.pre_ffn_norm.\2"),
(r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"),
# Decoder embedding
(r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"),
# Decoder blocks
(r"language_model\.layers\.(\d+)\.attention\.(.*)", r"decoder.blocks.\1.self_attn.\2"),
(r"language_model\.layers\.(\d+)\.attn_norm\.(.*)", r"decoder.blocks.\1.pre_attn_norm.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"decoder.blocks.\1.ffn.gate.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"decoder.blocks.\1.ffn.down.\2"),
(r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"decoder.blocks.\1.ffn.up.\2"),
(r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)", r"decoder.blocks.\1.pre_ffn_norm.\2"),
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)",
r"decoder.blocks.\1.adaptive_scale.proj_in.\2"),
(r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)",
r"decoder.blocks.\1.adaptive_scale.proj_out.\2"),
(r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"),
]
# Also remap o_proj → out_proj in both encoder and decoder
_POST_RENAME = [
(r"\.o_proj\.", r".out_proj."),
]
def _remap_converted_name(name: str) -> str:
"""Translate a converted-format weight name to our naming convention."""
for pattern, replacement in _CONVERTED_RULES:
result, n = re.subn(f"^{pattern}$", replacement, name)
if n:
name = result
break
for pattern, replacement in _POST_RENAME:
name = re.sub(pattern, replacement, name)
return name
# ---------------------------------------------------------------------------
# Loading strategies
# ---------------------------------------------------------------------------
def _has_converted_layout(path: Path) -> bool:
return (path / "config.json").exists() and not (path / "consolidated.safetensors").exists()
def _load_converted_weights(path: Path):
with open(path / "config.json") as f:
config = json.load(f)
model = VoxtralMLXModel(config)
quant = config.get("quantization")
if quant is not None:
gs = quant["group_size"]
nn.quantize(
model,
group_size=gs,
bits=quant["bits"],
class_predicate=lambda _p, m: (
hasattr(m, "to_quantized") and m.weight.shape[-1] % gs == 0
),
)
index_file = path / "model.safetensors.index.json"
if index_file.exists():
with open(index_file) as f:
shard_map = json.load(f)
shard_files = sorted(set(shard_map["weight_map"].values()))
weights = {}
for sf in shard_files:
weights.update(mx.load(str(path / sf)))
else:
weights = mx.load(str(path / "model.safetensors"))
remapped = {_remap_converted_name(k): v for k, v in weights.items()}
model.load_weights(list(remapped.items()))
mx.eval(model.parameters())
return model, config
def _load_original_weights(path: Path):
with open(path / "params.json") as f:
config = json.load(f)
model = VoxtralMLXModel(config)
raw = mx.load(str(path / "consolidated.safetensors"))
mapped: dict[str, mx.array] = {}
skipped: list[str] = []
for name, tensor in raw.items():
if name == "output.weight":
continue
new_name = _translate_weight_name(name)
if new_name is None:
skipped.append(name)
continue
# Conv weights: PyTorch [C_out, C_in, K] → MLX [C_out, K, C_in]
if _is_conv_weight(new_name):
tensor = mx.swapaxes(tensor, 1, 2)
mapped[new_name] = tensor
if skipped:
logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5])
model.load_weights(list(mapped.items()))
mx.eval(model.parameters())
return model, config
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
def _load_tokenizer(model_dir: Path):
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
return Tekkenizer.from_file(str(model_dir / "tekken.json"))
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID):
"""Load a Voxtral Realtime model and its tokenizer.
Args:
path_or_id: Local directory path **or** a HuggingFace model ID.
Returns:
``(model, tokenizer, config)``
"""
p = Path(path_or_id)
if not p.exists():
p = download_weights(path_or_id)
if _has_converted_layout(p):
model, config = _load_converted_weights(p)
else:
model, config = _load_original_weights(p)
tokenizer = _load_tokenizer(p)
logger.info("Voxtral MLX model loaded from %s", p)
return model, tokenizer, config

View File

@@ -0,0 +1,534 @@
"""
Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model.
Architecture:
audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits
with DelayEmbedding providing time-conditioning to the decoder.
The model supports both batch inference (full audio) and incremental streaming
(one chunk at a time with cached encoder/decoder state).
"""
import math
import mlx.core as mx
import mlx.nn as nn
# ---------------------------------------------------------------------------
# KV Cache
# ---------------------------------------------------------------------------
class SlidingKVCache:
"""Bounded key-value cache with rotating buffer for sliding-window attention.
Uses in-place writes for single-token autoregressive steps and
concatenation for multi-token prefills. Pre-allocates in blocks of
``alloc_step`` entries to reduce repeated allocation.
"""
alloc_step = 256
def __init__(self, capacity: int):
self.capacity = capacity
self.keys = None
self.values = None
self._offset = 0
self._write_idx = 0
@property
def offset(self) -> int:
return self._offset
# -- helpers --
def _reorder(self, buf):
"""Return *buf* in temporal order (unwrap the circular buffer)."""
if self._write_idx == buf.shape[2]:
return buf
if self._write_idx < self._offset:
return mx.concatenate(
[buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
axis=2,
)
return buf[..., : self._write_idx, :]
def _drop_oldest(self, buf, n_drop, tail=None):
parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
if tail is not None:
parts.append(tail)
return mx.concatenate(parts, axis=2)
# -- update strategies --
def _append_concat(self, k, v):
"""Multi-token update via concatenation (used during prefill)."""
if self.keys is None:
self.keys, self.values = k, v
else:
self.keys = self._reorder(self.keys)
self.values = self._reorder(self.values)
self._write_idx = self.keys.shape[2]
overflow = self._write_idx - self.capacity + 1
self.keys = self._drop_oldest(self.keys, overflow, k)
self.values = self._drop_oldest(self.values, overflow, v)
self._offset += k.shape[2]
self._write_idx = self.keys.shape[2]
return self.keys, self.values
def _write_inplace(self, k, v):
"""Single-token update via in-place write (autoregressive step)."""
B, n_heads, S, dim_k = k.shape
dim_v = v.shape[3]
prev = self._offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
):
n_new = min(self.alloc_step, self.capacity - prev)
fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
self.values = mx.concatenate([self.values, fresh_v], axis=2)
else:
self.keys, self.values = fresh_k, fresh_v
self._write_idx = prev
overflow = self.keys.shape[2] - self.capacity
if overflow > 0:
self.keys = self._drop_oldest(self.keys, overflow)
self.values = self._drop_oldest(self.values, overflow)
self._write_idx = self.capacity
if self._write_idx == self.capacity:
self._write_idx = 0
self.keys[..., self._write_idx : self._write_idx + S, :] = k
self.values[..., self._write_idx : self._write_idx + S, :] = v
self._offset += S
self._write_idx += S
if self._offset < self.capacity:
return (
self.keys[..., : self._offset, :],
self.values[..., : self._offset, :],
)
return self.keys, self.values
# -- public API --
def update_and_fetch(self, k, v):
if k.shape[2] == 1:
return self._write_inplace(k, v)
return self._append_concat(k, v)
# ---------------------------------------------------------------------------
# Encoder components
# ---------------------------------------------------------------------------
class CausalConv(nn.Module):
"""1-D causal convolution (left-padded so no future leakage)."""
def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
super().__init__()
self.stride = stride
self.kernel = kernel
self.left_pad = kernel - stride
self.weight = mx.zeros((channels_out, kernel, channels_in))
self.bias = mx.zeros((channels_out,))
def __call__(self, x: mx.array) -> mx.array:
if self.left_pad > 0:
x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
class _EncoderSelfAttention(nn.Module):
def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope_theta = rope_theta
def __call__(self, x, mask, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _EncoderFFN(nn.Module):
"""SwiGLU feed-forward for encoder layers."""
def __init__(self, dim: int, hidden: int):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=True)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class _EncoderBlock(nn.Module):
def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _EncoderFFN(dim, hidden)
def __call__(self, x, mask, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
x = x + self.ffn(self.pre_ffn_norm(x))
return x
class StreamingEncoder(nn.Module):
"""Causal Whisper-style encoder with two causal convolutions followed by
a stack of transformer blocks. Supports both full-sequence and
incremental (streaming) forward passes."""
def __init__(
self,
mel_channels: int = 128,
dim: int = 1280,
n_layers: int = 32,
n_heads: int = 32,
head_dim: int = 64,
hidden_dim: int = 5120,
rope_theta: float = 1e6,
sliding_window: int = 750,
):
super().__init__()
self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
self.blocks = [
_EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
self.sliding_window = sliding_window
# -- full-sequence --
def _apply_convs(self, mel: mx.array) -> mx.array:
x = mel.T[None, :, :] # [1, T, mel_channels]
x = nn.gelu(self.conv1(x))
x = nn.gelu(self.conv2(x))
return x
def forward(self, mel: mx.array) -> mx.array:
x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
for blk in self.blocks:
x = blk(x, mask="causal")
return self.final_norm(x)
# -- incremental (streaming) --
def forward_conv_incremental(self, x_in, tail1, tail2):
"""Process new mel frames through the two causal convs using cached tails.
Args:
x_in: [1, N, mel_channels]
tail1: [1, pad1, mel_channels] or None (first call)
tail2: [1, pad2, dim] or None (first call)
Returns:
(out, new_tail1, new_tail2)
"""
# Conv1 (kernel=3, stride=1 → left_pad=2)
if tail1 is not None:
c1_in = mx.concatenate([tail1, x_in], axis=1)
else:
c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
new_tail1 = x_in[:, -self.conv1.left_pad :, :]
c1_out = nn.gelu(
mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
)
# Conv2 (kernel=3, stride=2 → left_pad=1)
if tail2 is not None:
c2_in = mx.concatenate([tail2, c1_out], axis=1)
else:
c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
c2_out = nn.gelu(
mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
)
return c2_out, new_tail1, new_tail2
def forward_transformer_incremental(self, x, cache_list):
"""Run transformer blocks with per-layer KV caches."""
for i, blk in enumerate(self.blocks):
x = blk(x, mask="causal", cache=cache_list[i])
return self.final_norm(x)
# ---------------------------------------------------------------------------
# Decoder components
# ---------------------------------------------------------------------------
class _DecoderAttention(nn.Module):
"""Grouped-query attention for the text decoder."""
def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope_theta = rope_theta
def __call__(self, x, mask=None, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _DecoderFFN(nn.Module):
"""SwiGLU feed-forward for decoder layers."""
def __init__(self, dim, hidden):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=False)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class AdaptiveScaling(nn.Module):
"""Small MLP that produces a multiplicative scale from the delay embedding,
used to condition the FFN on the streaming delay."""
def __init__(self, dim, bottleneck):
super().__init__()
self.proj_in = nn.Linear(dim, bottleneck, bias=False)
self.proj_out = nn.Linear(bottleneck, dim, bias=False)
def __call__(self, cond):
return self.proj_out(nn.gelu(self.proj_in(cond)))
class _DecoderBlock(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _DecoderFFN(dim, hidden)
def __call__(self, x, delay_cond, mask=None, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond))
x = x + self.ffn(scaled)
return x
class TextDecoder(nn.Module):
"""Mistral-style causal language model with adaptive time-conditioning."""
def __init__(
self,
dim: int = 3072,
n_layers: int = 26,
n_heads: int = 32,
n_kv_heads: int = 8,
head_dim: int = 128,
hidden_dim: int = 9216,
vocab_size: int = 131072,
rope_theta: float = 1e6,
cond_dim: int = 32,
):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, dim)
self.blocks = [
_DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
def embed(self, token_ids: mx.array) -> mx.array:
return self.token_embedding(token_ids)
def __call__(self, x, delay_cond, mask=None, cache=None):
delay_cond = delay_cond.astype(x.dtype)
for i, blk in enumerate(self.blocks):
blk_cache = cache[i] if cache is not None else None
x = blk(x, delay_cond, mask, blk_cache)
x = self.final_norm(x)
return self.token_embedding.as_linear(x)
# ---------------------------------------------------------------------------
# Adapter & embeddings
# ---------------------------------------------------------------------------
class EncoderToDecoderAdapter(nn.Module):
"""Two-layer projection from encoder space to decoder space."""
def __init__(self, enc_dim: int, dec_dim: int):
super().__init__()
self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)
def __call__(self, x):
return self.linear2(nn.gelu(self.linear1(x)))
class DelayEmbedding(nn.Module):
"""Sinusoidal embedding that encodes the streaming delay as a conditioning
vector for the decoder's adaptive scaling."""
def __init__(self, dim: int = 3072, theta: float = 10000.0):
super().__init__()
self.dim = dim
half = dim // 2
freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
self._freqs = freqs
def __call__(self, delay: mx.array) -> mx.array:
t = delay.reshape(-1, 1).astype(mx.float32)
angles = t * self._freqs
return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
class VoxtralMLXModel(nn.Module):
"""Top-level Voxtral Realtime model wiring encoder, adapter, and decoder."""
def __init__(self, config: dict):
super().__init__()
enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
audio_cfg = enc_cfg["audio_encoding_args"]
ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]
self.encoder = StreamingEncoder(
mel_channels=audio_cfg["num_mel_bins"],
dim=enc_cfg["dim"],
n_layers=enc_cfg["n_layers"],
n_heads=enc_cfg["n_heads"],
head_dim=enc_cfg["head_dim"],
hidden_dim=enc_cfg["hidden_dim"],
rope_theta=enc_cfg["rope_theta"],
sliding_window=enc_cfg["sliding_window"],
)
adapter_input_dim = enc_cfg["dim"] * ds_factor
decoder_dim = config["dim"]
cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)
self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)
self.decoder = TextDecoder(
dim=decoder_dim,
n_layers=config["n_layers"],
n_heads=config["n_heads"],
n_kv_heads=config["n_kv_heads"],
head_dim=config["head_dim"],
hidden_dim=config["hidden_dim"],
vocab_size=config["vocab_size"],
rope_theta=config["rope_theta"],
cond_dim=cond_bottleneck,
)
self.delay_embedding = DelayEmbedding(dim=decoder_dim)
self.ds_factor = ds_factor
# -- batch encode --
def encode(self, mel: mx.array) -> mx.array:
T = mel.shape[1]
if T % 2 != 0:
mel = mel[:, 1:]
h = self.encoder.forward(mel) # [1, T/2, enc_dim]
h = h[0]
n = h.shape[0]
trim = n % self.ds_factor
if trim:
h = h[trim:]
n = h.shape[0]
h = h.reshape(n // self.ds_factor, -1)
return self.adapter(h)
# -- incremental encode --
def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
"""Incrementally encode new mel frames.
Returns:
(audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
"""
x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)
x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)
if enc_cache is None:
enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]
x = self.encoder.forward_transformer_incremental(x, enc_cache)
x = x[0] # [N, enc_dim]
if ds_remainder is not None:
x = mx.concatenate([ds_remainder, x])
n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
if n_full == 0:
return None, conv_tail1, conv_tail2, enc_cache, x
leftover = x[n_full:] if x.shape[0] > n_full else None
x = x[:n_full].reshape(n_full // self.ds_factor, -1)
return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover
# -- decode --
def decode(self, embeddings, delay_cond, mask=None, cache=None):
return self.decoder(embeddings, delay_cond, mask, cache)

View File

@@ -0,0 +1,202 @@
"""
Mel spectrogram computation for Voxtral Realtime.
Provides both a full-audio function and an incremental streaming variant
that maintains overlap state between calls. The DFT is computed via
matrix multiplication in MLX — no external FFT dependency required.
"""
import math
import mlx.core as mx
import numpy as np
# Audio / mel constants matching the Voxtral Realtime model expectations.
SAMPLE_RATE = 16_000
WINDOW_SIZE = 400 # n_fft
HOP = 160
MEL_BANDS = 128
MEL_MAX = 1.5 # global log-mel normalisation ceiling
# Each output audio token spans: hop * conv_stride(2) * downsample_factor(4)
SAMPLES_PER_TOKEN = HOP * 2 * 4 # = 1280 samples = 80 ms
# Padding tokens used by the model prompt structure.
LEFT_PAD_TOKENS = 32
RIGHT_PAD_TOKENS = 17
# ---------------------------------------------------------------------------
# Slaney mel filterbank
# ---------------------------------------------------------------------------
def _build_slaney_filterbank(
sr: int = SAMPLE_RATE,
n_fft: int = WINDOW_SIZE,
n_mels: int = MEL_BANDS,
lo_hz: float = 0.0,
hi_hz: float = 8000.0,
) -> np.ndarray:
"""Compute a Slaney-normalised triangular mel filterbank.
Returns an array of shape ``[n_mels, n_fft//2 + 1]``.
"""
def _hz2mel(f):
threshold = 1000.0
base_mel = 15.0
log_coeff = 27.0 / np.log(6.4)
mel = 3.0 * f / 200.0
if isinstance(f, np.ndarray):
above = f >= threshold
mel[above] = base_mel + np.log(f[above] / threshold) * log_coeff
elif f >= threshold:
mel = base_mel + np.log(f / threshold) * log_coeff
return mel
def _mel2hz(m):
threshold = 1000.0
base_mel = 15.0
log_coeff = np.log(6.4) / 27.0
hz = 200.0 * m / 3.0
above = m >= base_mel
hz[above] = threshold * np.exp(log_coeff * (m[above] - base_mel))
return hz
n_bins = n_fft // 2 + 1
fft_hz = np.linspace(0, sr / 2, n_bins)
mel_lo, mel_hi = _hz2mel(lo_hz), _hz2mel(hi_hz)
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
hz_pts = _mel2hz(mel_pts)
diffs = np.diff(hz_pts)
slopes = np.expand_dims(hz_pts, 0) - np.expand_dims(fft_hz, 1)
rising = -slopes[:, :-2] / diffs[:-1]
falling = slopes[:, 2:] / diffs[1:]
fb = np.maximum(0.0, np.minimum(rising, falling))
# Slaney area normalisation
widths = 2.0 / (hz_pts[2 : n_mels + 2] - hz_pts[:n_mels])
fb *= np.expand_dims(widths, 0)
return fb.T.astype(np.float32)
_CACHED_FILTERS: mx.array | None = None
def _mel_filters() -> mx.array:
global _CACHED_FILTERS
if _CACHED_FILTERS is None:
_CACHED_FILTERS = mx.array(_build_slaney_filterbank())
return _CACHED_FILTERS
# ---------------------------------------------------------------------------
# DFT helpers
# ---------------------------------------------------------------------------
def _hann_window() -> mx.array:
return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32))
def _dft_matrices():
"""Pre-compute the real / imaginary DFT basis matrices."""
n_bins = WINDOW_SIZE // 2 + 1
k = mx.arange(n_bins, dtype=mx.float32)[:, None]
n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
return mx.cos(phase), mx.sin(phase)
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
"""Frame *audio* using the Hann window and compute power spectrogram."""
n_bins = WINDOW_SIZE // 2 + 1
n_frames = 1 + (audio.shape[0] - WINDOW_SIZE) // HOP
if n_frames <= 0:
return mx.zeros((0, n_bins))
offsets = (mx.arange(n_frames) * HOP)[:, None]
indices = offsets + mx.arange(WINDOW_SIZE)[None, :]
windowed = audio[indices] * window[None, :]
dft_re, dft_im = _dft_matrices()
real_part = windowed @ dft_re.T
imag_part = windowed @ dft_im.T
return real_part ** 2 + imag_part ** 2
def _apply_mel_and_log(power: mx.array) -> mx.array:
"""Convert a power spectrogram to log-mel and normalise."""
mel = power @ _mel_filters().T
log_mel = mx.log10(mx.maximum(mel, 1e-10))
log_mel = mx.maximum(log_mel, MEL_MAX - 8.0)
return (log_mel + 4.0) / 4.0
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def compute_mel(audio: np.ndarray) -> mx.array:
"""Compute log-mel spectrogram for a complete audio signal.
Args:
audio: 1-D float32 numpy array at ``SAMPLE_RATE``.
Returns:
``[MEL_BANDS, T]`` MLX array.
"""
x = mx.array(audio)
pad = WINDOW_SIZE // 2
x = mx.pad(x, [(pad, pad)])
window = _hann_window()
power = _stft_frames(x, window)
# Drop last frame to match reference STFT behaviour
power = power[:-1]
return _apply_mel_and_log(power).T
def compute_mel_streaming(
chunk: np.ndarray,
overlap: np.ndarray | None,
) -> tuple[mx.array, np.ndarray]:
"""Incrementally compute log-mel for a new audio chunk.
Args:
chunk: New audio samples (float32 numpy).
overlap: The last ``WINDOW_SIZE - HOP`` = 240 samples from the
previous call, or *None* on the first call (uses zero-padding).
Returns:
``(mel, new_overlap)`` where *mel* is ``[MEL_BANDS, N]`` and
*new_overlap* is the 240-sample tail for the next call.
"""
tail_len = WINDOW_SIZE - HOP # 240
if overlap is not None:
combined = np.concatenate([overlap, chunk])
else:
combined = np.concatenate([np.zeros(WINDOW_SIZE // 2, dtype=np.float32), chunk])
new_overlap = combined[-tail_len:].copy()
x = mx.array(combined)
window = _hann_window()
power = _stft_frames(x, window)
if power.shape[0] == 0:
return mx.zeros((MEL_BANDS, 0)), new_overlap
return _apply_mel_and_log(power).T, new_overlap
def pad_audio(
audio: np.ndarray,
n_left: int = LEFT_PAD_TOKENS,
n_right: int = RIGHT_PAD_TOKENS,
) -> np.ndarray:
"""Pad audio with silence for batch (non-streaming) inference."""
left = n_left * SAMPLES_PER_TOKEN
align = (SAMPLES_PER_TOKEN - (len(audio) % SAMPLES_PER_TOKEN)) % SAMPLES_PER_TOKEN
right = align + n_right * SAMPLES_PER_TOKEN
return np.pad(audio, (left, right))

View File

@@ -0,0 +1,521 @@
"""
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
(streaming processor) that plug into WhisperLiveKit's audio processing
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
Unlike the HuggingFace backend, this runs the full inference loop in-process
(no background thread / queue) — MLX operations on Apple Silicon are fast
enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
"""
import logging
import sys
import time
from typing import List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
from whisperlivekit.voxtral_mlx.spectrogram import (
SAMPLES_PER_TOKEN,
LEFT_PAD_TOKENS,
RIGHT_PAD_TOKENS,
compute_mel_streaming,
)
logger = logging.getLogger(__name__)
# Decoder sliding-window size (matches the model's training configuration).
_DECODER_WINDOW = 8192
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
"""Build the prompt token sequence and return ``(token_ids, n_delay)``."""
pad_id = tokenizer.get_special_token("[STREAMING_PAD]")
ids = [tokenizer.bos_id] + [pad_id] * (n_left_pad + n_delay)
return ids, n_delay
# ---------------------------------------------------------------------------
# Model holder
# ---------------------------------------------------------------------------
class VoxtralMLXASR:
"""Lightweight model holder — loads the MLX Voxtral model once and keeps
it alive for the lifetime of the server."""
sep = " "
SAMPLING_RATE = 16_000
def __init__(self, logfile=sys.stderr, **kwargs):
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = DEFAULT_MODEL_ID
t0 = time.time()
logger.info("Loading Voxtral MLX model '%s' ...", model_path)
self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
logger.info("Voxtral MLX model loaded in %.2fs", time.time() - t0)
self.backend_choice = "voxtral-mlx"
def transcribe(self, audio):
pass # all work happens in the online processor
# ---------------------------------------------------------------------------
# Online processor
# ---------------------------------------------------------------------------
class VoxtralMLXOnlineProcessor:
"""Streaming processor that incrementally encodes audio and decodes text
using the MLX Voxtral model.
Lifecycle (called by ``AudioProcessor.transcription_processor``):
insert_audio_chunk(pcm, time) → process_iter() → get_buffer()
... repeat ...
start_silence() / end_silence()
finish()
"""
SAMPLING_RATE = 16_000
def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer: list = []
self.audio_buffer = np.array([], dtype=np.float32)
self._model = asr.model
self._tokenizer = asr.tokenizer
# Pre-compute prompt tokens and delay conditioning (constant across utterances).
self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
self._prefix_len = len(self._prompt_ids)
self._delay_cond = self._model.delay_embedding(
mx.array([self._n_delay], dtype=mx.float32)
)
mx.eval(self._delay_cond)
self._prompt_embeds = self._model.decoder.embed(
mx.array([self._prompt_ids])
)[0] # [prefix_len, dim]
mx.eval(self._prompt_embeds)
self._eos_id = self._tokenizer.eos_id
self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
# The streaming model has an inherent delay: text for audio at position P
# is generated at decoder position P + n_delay. Compensate timestamps.
self._delay_secs = self._n_delay * self._secs_per_token
self._reset_state()
# -- state management --
def _reset_state(self):
"""Reset all incremental state for a fresh utterance."""
# Audio accumulation
self._pending = np.zeros(0, dtype=np.float32)
# Mel overlap
self._mel_overlap: np.ndarray | None = None
# Encoder incremental state
self._conv_tail1 = None
self._conv_tail2 = None
self._enc_cache = None
self._ds_remainder = None
# Audio embeddings not yet decoded
self._audio_embeds: mx.array | None = None
# Decoder state
self._dec_cache: list[SlidingKVCache] | None = None
self._last_token: mx.array | None = None
# Bookkeeping
self._samples_encoded = 0
self._positions_decoded = 0
self._prefilled = False
self._first_chunk = True
# Text state
self._full_text = ""
self._n_text_tokens = 0
self._n_committed_words = 0
self._time_offset = 0.0
# Per-word audio position tracking: decoder position (relative to prefix)
# where each word in _full_text started and ended
self._word_audio_starts: list[int] = [] # audio pos where word i started
self._word_audio_ends: list[int] = [] # audio pos where word i last produced a token
self._current_word_pos: Optional[int] = None # audio pos of current (incomplete) word's first token
# -- audio ingestion --
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending = np.append(self._pending, audio)
self.audio_buffer = self._pending
# -- core processing --
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._step(is_last)
except Exception as e:
logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
return [], self.end
def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
# 1. Encode any new audio
self._encode_pending()
if self._audio_embeds is None:
return [], self.end
# 2. Compute how many positions we can safely decode
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
n_available = self._audio_embeds.shape[0]
n_decodable = min(n_available, total_safe - self._positions_decoded)
if n_decodable <= 0:
return [], self.end
# 3. Prefill if needed
if not self._prefilled:
if self._positions_decoded + n_available < self._prefix_len:
return [], self.end
self._do_prefill()
# Re-check after consuming prefix embeddings
n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
n_decodable = min(n_available, total_safe - self._positions_decoded)
if n_decodable <= 0 or self._audio_embeds is None:
return [], self.end
# 4. Decode available positions
hit_eos = self._decode_positions(n_decodable)
if hit_eos:
# Flush words, reset for next utterance
words = self._flush_all_words()
logger.debug(
"[voxtral-mlx] EOS hit during stream: flushed %d words, "
"samples_encoded=%d (%.2fs), text='%s'",
len(words), self._samples_encoded,
self._samples_encoded / self.SAMPLING_RATE,
self._full_text[-60:] if self._full_text else "",
)
saved_offset = self._time_offset
self._reset_state()
self._time_offset = saved_offset
return words, self.end
# 5. Extract committed words (all but the last, which may still grow)
return self._extract_committed_words(), self.end
def _encode_pending(self):
"""Feed pending audio through the incremental encoder."""
available = len(self._pending)
if available < SAMPLES_PER_TOKEN:
return
if self._first_chunk:
# First chunk: prepend silence for left-padding
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
chunk = np.concatenate([left_pad, self._pending[:n_take]])
self._pending = self._pending[n_take:]
self._samples_encoded += n_take
self._first_chunk = False
else:
n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
chunk = self._pending[:n_take]
self._pending = self._pending[n_take:]
self._samples_encoded += n_take
mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)
embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
self._model.encode_incremental(
mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
)
)
if embeds is not None:
mx.eval(embeds)
if self._audio_embeds is not None:
self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
else:
self._audio_embeds = embeds
self.audio_buffer = self._pending
def _do_prefill(self):
"""Run the decoder prefill pass over the prompt + first audio embeddings."""
n_dec_layers = len(self._model.decoder.blocks)
self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]
prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
prefix_embeds = prefix_embeds[None, :, :] # [1, prefix_len, dim]
logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])
self._last_token = self._sample(logits)
mx.async_eval(self._last_token)
# Remove consumed prefix embeddings
self._audio_embeds = self._audio_embeds[self._prefix_len :]
if self._audio_embeds.shape[0] == 0:
self._audio_embeds = None
self._positions_decoded = self._prefix_len
self._prefilled = True
def _decode_positions(self, n: int) -> bool:
"""Autoregressively decode *n* positions. Returns True on EOS."""
base_pos = self._positions_decoded # absolute position before this batch
for i in range(n):
tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
next_tok = self._sample(logits)
mx.async_eval(next_tok)
token_id = self._last_token.item()
if token_id == self._eos_id:
# Close the current word if one is being built
if self._current_word_pos is not None:
self._word_audio_ends.append(base_pos + i - self._prefix_len)
self._current_word_pos = None
self._trim_embeds(i)
self._positions_decoded += i
return True
text = self._tokenizer.decode(
[token_id], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
audio_pos = base_pos + i - self._prefix_len
# Detect word boundary: new word starts with space or is the very first text
if text.lstrip() != text or not self._full_text:
# Close previous word if exists
if self._current_word_pos is not None:
self._word_audio_ends.append(audio_pos)
# Start new word
self._word_audio_starts.append(audio_pos)
self._current_word_pos = audio_pos
elif self._current_word_pos is None:
# First token of first word (no leading space)
self._word_audio_starts.append(audio_pos)
self._current_word_pos = audio_pos
self._full_text += text
self._n_text_tokens += 1
if i > 0 and i % 256 == 0:
mx.clear_cache()
self._last_token = next_tok
self._positions_decoded += n
self._trim_embeds(n)
return False
def _trim_embeds(self, n_consumed: int):
if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
self._audio_embeds = self._audio_embeds[n_consumed:]
else:
self._audio_embeds = None
def _sample(self, logits: mx.array) -> mx.array:
return mx.argmax(logits[0, -1:], axis=-1).squeeze()
# -- word extraction --
def _audio_pos_to_time(self, pos: int) -> float:
"""Convert an audio position (relative to prefix end) to seconds."""
return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)
def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
"""Compute (start, end) time for a word using tracked word positions."""
starts = self._word_audio_starts
ends = self._word_audio_ends
if not starts:
return self._time_offset, self._time_offset
# Get start position for this word
if word_idx < len(starts):
t0 = self._audio_pos_to_time(starts[word_idx])
else:
# Fallback: estimate from last known position
last_pos = ends[-1] if ends else starts[-1]
t0 = self._audio_pos_to_time(last_pos + 1)
# Get end position: use the start of the next word, or the end of this word
if word_idx + 1 < len(starts):
t1 = self._audio_pos_to_time(starts[word_idx + 1])
elif word_idx < len(ends):
t1 = self._audio_pos_to_time(ends[word_idx] + 1)
else:
# Last word, still being built: use last known position + 1 token
last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
t1 = self._audio_pos_to_time(last_pos + 1)
return t0, t1
def _extract_committed_words(self) -> List[ASRToken]:
"""Return complete words (all except the last which may still grow)."""
if not self._full_text:
return []
words = self._full_text.split()
tokens: List[ASRToken] = []
n_total = max(len(words), 1)
while len(words) > self._n_committed_words + 1:
w = words[self._n_committed_words]
idx = self._n_committed_words
t0, t1 = self._word_time_range(idx, n_total)
label = w if idx == 0 else " " + w
tokens.append(ASRToken(start=t0, end=t1, text=label))
self._n_committed_words += 1
return tokens
def _flush_all_words(self) -> List[ASRToken]:
"""Flush every word including the last partial one."""
if not self._full_text:
return []
words = self._full_text.split()
tokens: List[ASRToken] = []
n_total = max(len(words), 1)
while self._n_committed_words < len(words):
w = words[self._n_committed_words]
idx = self._n_committed_words
t0, t1 = self._word_time_range(idx, n_total)
label = w if idx == 0 else " " + w
tokens.append(ASRToken(start=t0, end=t1, text=label))
self._n_committed_words += 1
return tokens
# -- interface methods --
def get_buffer(self) -> Transcript:
if not self._full_text:
return Transcript(start=None, end=None, text="")
words = self._full_text.split()
remaining = words[self._n_committed_words :]
if remaining:
return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
words = self._flush_all_words()
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
logger.debug(
"[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
"samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
len(self._pending),
self._audio_embeds.shape if self._audio_embeds is not None else None,
self._samples_encoded,
self._positions_decoded,
self._prefilled,
self._full_text[-80:] if self._full_text else "",
)
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
remainder = len(self._pending) % SAMPLES_PER_TOKEN
if remainder > 0:
align_pad = SAMPLES_PER_TOKEN - remainder
else:
align_pad = 0
# Add alignment + right-padding silence
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
if total_pad > 0:
self._pending = np.append(
self._pending, np.zeros(total_pad, dtype=np.float32)
)
# Encode remaining audio (including right-padding)
self._encode_pending()
logger.debug(
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
self._audio_embeds.shape if self._audio_embeds is not None else None,
len(self._pending),
)
hit_eos = False
# Decode everything that's left from right-padding
if self._audio_embeds is not None and self._prefilled:
hit_eos = self._decode_positions(self._audio_embeds.shape[0])
logger.debug(
"[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
hit_eos, self._full_text[-80:] if self._full_text else "",
)
# Flush last token if it wasn't EOS
if self._last_token is not None:
tid = self._last_token.item()
if tid != self._eos_id:
text = self._tokenizer.decode(
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
last_pos = self._positions_decoded - self._prefix_len
# Check if this starts a new word
if text.lstrip() != text or not self._full_text:
if self._current_word_pos is not None:
self._word_audio_ends.append(last_pos)
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
elif self._current_word_pos is None:
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
self._full_text += text
self._n_text_tokens += 1
# Close the last word if still open
if self._current_word_pos is not None:
last_pos = self._positions_decoded - self._prefix_len
self._word_audio_ends.append(last_pos)
self._current_word_pos = None
words = self._flush_all_words()
logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
return words, self.end

View File

@@ -7,6 +7,7 @@ def load_file(warmup_file=None, timeout=5):
import os import os
import tempfile import tempfile
import urllib.request import urllib.request
import librosa import librosa
if warmup_file == "": if warmup_file == "":

View File

@@ -490,6 +490,11 @@ label {
margin-left: 4px; margin-left: 4px;
} }
.buffer_translation {
color: #a0a0a0;
margin-left: 6px;
}
.spinner { .spinner {
display: inline-block; display: inline-block;
width: 8px; width: 8px;

View File

@@ -232,10 +232,11 @@ function setupWebSocket() {
if (waitingForStop) { if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed."; statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) { if (lastReceivedData) {
renderLinesWithBuffer( renderLinesWithBuffer(
lastReceivedData.lines || [], lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "", lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "", lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0, 0,
0, 0,
true true
@@ -281,6 +282,7 @@ function setupWebSocket() {
lastReceivedData.lines || [], lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "", lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "", lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
0, 0,
0, 0,
true true
@@ -301,6 +303,7 @@ function setupWebSocket() {
lines = [], lines = [],
buffer_transcription = "", buffer_transcription = "",
buffer_diarization = "", buffer_diarization = "",
buffer_translation = "",
remaining_time_transcription = 0, remaining_time_transcription = 0,
remaining_time_diarization = 0, remaining_time_diarization = 0,
status = "active_transcription", status = "active_transcription",
@@ -310,6 +313,7 @@ function setupWebSocket() {
lines, lines,
buffer_diarization, buffer_diarization,
buffer_transcription, buffer_transcription,
buffer_translation,
remaining_time_diarization, remaining_time_diarization,
remaining_time_transcription, remaining_time_transcription,
false, false,
@@ -323,6 +327,7 @@ function renderLinesWithBuffer(
lines, lines,
buffer_diarization, buffer_diarization,
buffer_transcription, buffer_transcription,
buffer_translation,
remaining_time_diarization, remaining_time_diarization,
remaining_time_transcription, remaining_time_transcription,
isFinalizing = false, isFinalizing = false,
@@ -341,6 +346,7 @@ function renderLinesWithBuffer(
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })), lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
buffer_transcription: buffer_transcription || "", buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "", buffer_diarization: buffer_diarization || "",
buffer_translation: buffer_translation,
status: current_status, status: current_status,
showLoading, showLoading,
showTransLag, showTransLag,
@@ -385,12 +391,11 @@ function renderLinesWithBuffer(
if (idx === lines.length - 1) { if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) { if (!isFinalizing && item.speaker !== -2) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1( speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
remaining_time_transcription remaining_time_transcription
)}</span>s</span></span>`; )}</span>s</span></span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) { if (buffer_diarization && remaining_time_diarization) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1( speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
remaining_time_diarization remaining_time_diarization
)}</span>s</span></span>`; )}</span>s</span></span>`;
@@ -415,13 +420,22 @@ function renderLinesWithBuffer(
} }
} }
} }
let translationContent = "";
if (item.translation) { if (item.translation) {
translationContent += item.translation.trim();
}
if (idx === lines.length - 1 && buffer_translation) {
const bufferPiece = isFinalizing
? buffer_translation
: `<span class="buffer_translation">${buffer_translation}</span>`;
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
}
if (translationContent.trim().length > 0) {
currentLineText += ` currentLineText += `
<div> <div>
<div class="label_translation"> <div class="label_translation">
${translationIcon} ${translationIcon}
<span>${item.translation}</span> <span class="translation_text">${translationContent}</span>
</div> </div>
</div>`; </div>`;
} }

View File

@@ -1,6 +1,6 @@
import logging
import importlib.resources as resources
import base64 import base64
import importlib.resources as resources
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -96,11 +96,13 @@ def get_inline_ui_html():
if __name__ == '__main__': if __name__ == '__main__':
import pathlib
import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
import uvicorn
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg import whisperlivekit.web as webpkg
app = FastAPI() app = FastAPI()

View File

@@ -0,0 +1,642 @@
import hashlib
import io
import json
import os
import urllib
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from tqdm import tqdm
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
pad_or_trim)
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
decode, detect_language)
from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
"""
attempt to infer ModelDimensions from a HF style config.json located
next to the given checkpoint, usefull for distilled models/MLX models.
"""
candidates = []
if os.path.isdir(path):
candidates.append(os.path.join(path, "config.json"))
else:
candidates.append(os.path.join(os.path.dirname(path), "config.json"))
for candidate in candidates:
if not os.path.isfile(candidate):
continue
with open(candidate, "r", encoding="utf-8") as f:
config = json.load(f)
# native Whisper format
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
"n_text_head", "n_text_layer"]
if all(k in config for k in native_keys):
return ModelDimensions(
n_mels=config["n_mels"],
n_audio_ctx=config["n_audio_ctx"],
n_audio_state=config["n_audio_state"],
n_audio_head=config["n_audio_head"],
n_audio_layer=config["n_audio_layer"],
n_vocab=config["n_vocab"],
n_text_ctx=config["n_text_ctx"],
n_text_state=config["n_text_state"],
n_text_head=config["n_text_head"],
n_text_layer=config["n_text_layer"],
)
# HuggingFace format
try:
return ModelDimensions(
n_mels=config["num_mel_bins"],
n_audio_ctx=config["max_source_positions"],
n_audio_state=config["d_model"],
n_audio_head=config["encoder_attention_heads"],
n_audio_layer=config.get("encoder_layers")
or config["num_hidden_layers"],
n_vocab=config["vocab_size"],
n_text_ctx=config["max_target_positions"],
n_text_state=config["d_model"],
n_text_head=config["decoder_attention_heads"],
n_text_layer=config["decoder_layers"],
)
except KeyError as err:
warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
return None
return None
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
converts a HF checkpoint state_dict into the naming convention used by
default whisper
"""
if not any(k.startswith("model.") for k in state_dict):
return state_dict
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
if remainder.startswith("self_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "attn.query",
"k_proj": "attn.key",
"v_proj": "attn.value",
"out_proj": "attn.out",
}
stem = mapping.get(suffix.split(".")[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "self_attn_layer_norm.weight":
return f"{target_prefix}.attn_ln.weight"
elif remainder == "self_attn_layer_norm.bias":
return f"{target_prefix}.attn_ln.bias"
elif remainder.startswith("encoder_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "cross_attn.query",
"k_proj": "cross_attn.key",
"v_proj": "cross_attn.value",
"out_proj": "cross_attn.out",
}
stem = mapping.get(suffix.split(".", 1)[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "encoder_attn_layer_norm.weight":
return f"{target_prefix}.cross_attn_ln.weight"
elif remainder == "encoder_attn_layer_norm.bias":
return f"{target_prefix}.cross_attn_ln.bias"
elif remainder.startswith("fc1."):
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
elif remainder.startswith("fc2."):
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
elif remainder == "final_layer_norm.weight":
return f"{target_prefix}.mlp_ln.weight"
elif remainder == "final_layer_norm.bias":
return f"{target_prefix}.mlp_ln.bias"
return None
converted = {}
for key, value in state_dict.items():
if not key.startswith("model."):
continue
subkey = key[len("model.") :]
if subkey.startswith("encoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("decoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
mapped = subkey
elif subkey == "encoder.embed_positions.weight":
mapped = "encoder.positional_embedding"
elif subkey == "decoder.embed_positions.weight":
mapped = "decoder.positional_embedding"
elif subkey == "encoder.layer_norm.weight":
mapped = "encoder.ln_post.weight"
elif subkey == "encoder.layer_norm.bias":
mapped = "encoder.ln_post.bias"
elif subkey.startswith("decoder.embed_tokens."):
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
elif subkey == "decoder.layer_norm.weight":
mapped = "decoder.ln.weight"
elif subkey == "decoder.layer_norm.bias":
mapped = "decoder.ln.bias"
else:
mapped = None
if mapped:
converted[mapped] = value
return converted if converted else state_dict
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Converts an mlx whisper checkpoint to a default openai whisper one
"""
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
return state_dict
converted = {}
for key, value in state_dict.items():
if key == "alignment_heads":
continue
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
converted[new_key] = value
return converted
def _load_lora_state(lora_path: str):
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin")
if os.path.isfile(safe_path):
try:
from safetensors.torch import load_file
except ImportError as exc:
raise ImportError(
"Loading LoRA adapters stored as .safetensors requires the `safetensors` package."
) from exc
return load_file(safe_path)
if os.path.isfile(bin_path):
return torch.load(bin_path, map_location="cpu")
raise FileNotFoundError(
f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin."
)
def _collapse_hf_module_name(module: str):
if module.startswith("base_model."):
module = module[len("base_model.") :]
if module.startswith("model.model."):
module = module[len("model.") :]
if not module.startswith("model."):
module = f"model.{module}"
return module
def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
"""
Resolve LoRA adapter path - handles both local paths and HuggingFace repo IDs.
If lora_path is a local directory containing adapter files, returns it as-is.
If lora_path looks like a HuggingFace repo ID (contains '/'), downloads and caches it.
"""
if not lora_path:
return None
# Check if it's already a valid local path
if os.path.isdir(lora_path):
config_path = os.path.join(lora_path, "adapter_config.json")
if os.path.isfile(config_path):
return lora_path
# Try to download from HuggingFace Hub
if "/" in lora_path:
try:
from huggingface_hub import snapshot_download
local_path = snapshot_download(
repo_id=lora_path,
allow_patterns=["adapter_config.json", "adapter_model.*"],
)
return local_path
except Exception as e:
raise FileNotFoundError(
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
)
raise FileNotFoundError(
f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
)
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
if not lora_path:
return
# Resolve path (handles HuggingFace Hub download)
lora_path = _resolve_lora_path(lora_path)
if not lora_path:
return
config_path = os.path.join(lora_path, "adapter_config.json")
if not os.path.isfile(config_path):
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
with open(config_path, "r", encoding="utf-8") as handle:
config = json.load(handle)
if config.get("peft_type") != "LORA":
raise ValueError("Only LoRA adapters are supported.")
r = config.get("r")
alpha = config.get("lora_alpha") or config.get("alpha")
if not r or not alpha:
raise ValueError("LoRA config must include `r` and `lora_alpha`.")
scaling = alpha / r
adapter_state = _load_lora_state(lora_path)
lora_layers: Dict[str, Dict[str, Tensor]] = {}
for key, tensor in adapter_state.items():
if key.endswith("lora_A.weight"):
module = key[: -len(".lora_A.weight")]
lora_layers.setdefault(module, {})["A"] = tensor
elif key.endswith("lora_B.weight"):
module = key[: -len(".lora_B.weight")]
lora_layers.setdefault(module, {})["B"] = tensor
if not lora_layers:
raise ValueError(f"No LoRA tensors found in {lora_path}")
for module, parts in lora_layers.items():
if "A" not in parts or "B" not in parts:
raise ValueError(f"Incomplete LoRA tensors for module '{module}'")
hf_module = _collapse_hf_module_name(module)
hf_weight_key = f"{hf_module}.weight"
delta = parts["B"] @ parts["A"]
delta = delta * scaling
converted = _convert_hf_state_dict({hf_weight_key: delta})
if not converted:
raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.")
target_name, delta_tensor = next(iter(converted.items()))
if target_name not in state_dict:
raise KeyError(
f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter."
)
state_dict[target_name] = state_dict[target_name] + delta_tensor.to(
dtype=state_dict[target_name].dtype, device=state_dict[target_name].device
)
def _load_checkpoint(
file_path: Union[str, Path],
device: str,
in_memory: bool = False,
checkpoint_bytes: Optional[bytes] = None,
) -> Dict[str, torch.Tensor]:
"""
Load a checkpoint from a single file.
Handles .pt, .bin, and .safetensors formats.
"""
if checkpoint_bytes is not None:
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
file_path = Path(file_path)
suffix = file_path.suffix.lower()
if suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load .safetensors model files: `pip install safetensors`"
)
return load_file(str(file_path), device=device)
else:
if in_memory:
with open(file_path, "rb") as f:
checkpoint_bytes = f.read()
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
else:
with open(file_path, "rb") as fp:
return torch.load(fp, map_location=device)
def _load_sharded_checkpoint(
shard_files: List[Path],
device: str,
) -> Dict[str, torch.Tensor]:
"""
Load a sharded checkpoint (multiple .safetensors or .bin files).
Merges all shards into a single state dict.
"""
merged_state_dict = {}
first_suffix = shard_files[0].suffix.lower()
if first_suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load sharded .safetensors model: `pip install safetensors`"
)
for shard_path in shard_files:
shard_dict = load_file(str(shard_path), device=device)
merged_state_dict.update(shard_dict)
else:
for shard_path in shard_files:
with open(shard_path, "rb") as fp:
shard_dict = torch.load(fp, map_location=device)
if isinstance(shard_dict, dict):
merged_state_dict.update(shard_dict)
return merged_state_dict
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only: bool = False,
custom_alignment_heads: Optional[str] = None,
lora_path: Optional[str] = None,
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
Can be a single file (.pt, .bin, .safetensors), a directory containing model files,
or a sharded model directory with files like model-00001-of-00002.safetensors.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
lora_path: str
optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model)
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
from whisperlivekit.model_paths import detect_model_format
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
checkpoint = None
model_path_for_config = name # Used to find config.json for dims inference
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
if in_memory:
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_file)
else:
checkpoint = _load_checkpoint(checkpoint_file, device)
elif os.path.isfile(name):
if in_memory:
with open(name, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(name, device)
model_path_for_config = name
elif os.path.isdir(name):
model_info = detect_model_format(name)
if not model_info.has_pytorch:
raise RuntimeError(
f"No PyTorch checkpoint found in directory {name}. "
f"Expected .pt, .bin, or .safetensors file(s)."
)
if model_info.is_sharded:
checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
else:
single_file = model_info.pytorch_files[0]
if in_memory:
with open(single_file, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(single_file, device)
model_path_for_config = name
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
if alignment_heads is None and "alignment_heads" in state_dict:
alignment_heads = state_dict["alignment_heads"]
state_dict = _convert_hf_state_dict(state_dict)
state_dict = _convert_mlx_state_dict(state_dict)
_apply_lora_adapter(state_dict, lora_path)
if dims_cfg is not None:
dims = ModelDimensions(**dims_cfg)
else:
dims = _infer_dims_from_config(model_path_for_config)
if dims is None:
raise RuntimeError(
"Could not determine model dimensions. "
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
)
if not isinstance(state_dict, dict):
state_dict = checkpoint
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
state_dict = {
k: v for k, v in state_dict.items()
if 'encoder' not in k
}
model.load_state_dict(state_dict)
if alignment_heads is not None:
if isinstance(alignment_heads, bytes):
model.set_alignment_heads(alignment_heads)
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
for layer, head in alignment_heads.tolist():
mask[layer, head] = True
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
return model.to(device)
def convert_encoder_to_coreml(
model_name = "base",
output_path= "whisper_encoder.mlpackage",
dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
precision = "float16",
):
import coremltools as ct
model = load_model(model_name, device="cpu", decoder_only=False)
encoder = model.encoder.eval().cpu()
dummy_input = torch.randn(
1,
model.dims.n_mels,
dummy_frames,
dtype=next(encoder.parameters()).dtype,
)
with torch.no_grad():
traced_encoder = torch.jit.trace(encoder, dummy_input)
precision_map = {
"float16": ct.precision.FLOAT16,
"fp16": ct.precision.FLOAT16,
"float32": ct.precision.FLOAT32,
"fp32": ct.precision.FLOAT32,
}
coreml_precision = precision_map[precision.lower()]
mlmodel = ct.convert(
traced_encoder,
inputs=[ct.TensorType(name="mel", shape=dummy_input.shape)],
convert_to= "mlprogram",
compute_precision=coreml_precision,
)
output_path = Path(output_path)
mlmodel.save(str(output_path))
return output_path
# if __name__ == "__main__":
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")

Some files were not shown because too many files have changed in this diff Show More