Compare commits
337 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bc0937c46 | ||
|
|
929cf7a26b | ||
|
|
abfaf06203 | ||
|
|
d1fe932241 | ||
|
|
c112ceffb6 | ||
|
|
4917406e06 | ||
|
|
b63f54e838 | ||
|
|
c56a53fbf4 | ||
|
|
66e58624b9 | ||
|
|
9366e067f9 | ||
|
|
866c25670c | ||
|
|
2553ef283e | ||
|
|
73e7fafc48 | ||
|
|
bbcebcb1fe | ||
|
|
4bb58dc7aa | ||
|
|
27ca028479 | ||
|
|
d24805cc18 | ||
|
|
994ce21365 | ||
|
|
132823dc09 | ||
|
|
d6d8c2635f | ||
|
|
8fedeb9fed | ||
|
|
b1fc23807a | ||
|
|
10c4e5f730 | ||
|
|
c76b2ef2c6 | ||
|
|
4b2377c243 | ||
|
|
a4da246ea5 | ||
|
|
9b2c3ee844 | ||
|
|
83d0fa3fac | ||
|
|
5a12c627b4 | ||
|
|
f5eee67b11 | ||
|
|
4a6868e3e1 | ||
|
|
3c15246fc0 | ||
|
|
d337248fda | ||
|
|
b8d9d7d289 | ||
|
|
4c7706e2cf | ||
|
|
7f3a3df620 | ||
|
|
e7e82f7c19 | ||
|
|
8c799fa4d1 | ||
|
|
8923337380 | ||
|
|
aded1649ae | ||
|
|
3b535e857a | ||
|
|
d649250b9a | ||
|
|
7735478286 | ||
|
|
b9e72d2b9a | ||
|
|
e5b01033af | ||
|
|
6ae545bcb1 | ||
|
|
04980d3f5e | ||
|
|
79a705c969 | ||
|
|
34e4abd455 | ||
|
|
d59ddbaeae | ||
|
|
4dd66e7766 | ||
|
|
3db5d81a20 | ||
|
|
b67ddea494 | ||
|
|
3192553e20 | ||
|
|
f379a243fe | ||
|
|
ec09898a9f | ||
|
|
befbae56c7 | ||
|
|
bbd4fd6cff | ||
|
|
28985962a0 | ||
|
|
a38c103fcd | ||
|
|
4d2ffb24f8 | ||
|
|
1bbbb7903c | ||
|
|
bcffdbc6b3 | ||
|
|
80b77998f9 | ||
|
|
d310f7e25f | ||
|
|
8d9be88fe6 | ||
|
|
16461052ed | ||
|
|
5491dbd824 | ||
|
|
13401ffe24 | ||
|
|
7108d2ddc5 | ||
|
|
a732e0903e | ||
|
|
0491681be4 | ||
|
|
ffe5284764 | ||
|
|
41ca17acda | ||
|
|
06b31f51eb | ||
|
|
ece02db6a3 | ||
|
|
939a7ebf8b | ||
|
|
61edb70fff | ||
|
|
4e455b8aab | ||
|
|
9434390ad3 | ||
|
|
65250db92c | ||
|
|
416dce7975 | ||
|
|
0c5365e7c6 | ||
|
|
19e9d76610 | ||
|
|
e7b05b0138 | ||
|
|
818c9c37ca | ||
|
|
714fb3b14a | ||
|
|
0af379c465 | ||
|
|
9c5bb5df19 | ||
|
|
dc6ea79036 | ||
|
|
21bbb59e31 | ||
|
|
12a69205ed | ||
|
|
1f684cdd97 | ||
|
|
3467109668 | ||
|
|
971f8473eb | ||
|
|
8434ef5efc | ||
|
|
290470dd60 | ||
|
|
425ac7b51d | ||
|
|
0382cfbeba | ||
|
|
9b1e061b32 | ||
|
|
b4abc158b9 | ||
|
|
5832d7433d | ||
|
|
3736458503 | ||
|
|
374618e050 | ||
|
|
543972ef38 | ||
|
|
73f36cc0ef | ||
|
|
a7db39d999 | ||
|
|
a153e11fe0 | ||
|
|
ca6f9246cc | ||
|
|
d080d675a8 | ||
|
|
40bff38933 | ||
|
|
2fe3ca0188 | ||
|
|
545ea15c9a | ||
|
|
8cbaeecc75 | ||
|
|
70e854b346 | ||
|
|
d55490cd27 | ||
|
|
1fa9e1f656 | ||
|
|
994f30e1ed | ||
|
|
b22478c0b4 | ||
|
|
94c34efd90 | ||
|
|
32099b9275 | ||
|
|
9fc6654a4a | ||
|
|
d24c110d55 | ||
|
|
4dd5d8bf8a | ||
|
|
cd9a32a36b | ||
|
|
6caf3e0485 | ||
|
|
93f002cafb | ||
|
|
c5e30c2c07 | ||
|
|
1c2afb8bd2 | ||
|
|
674b20d3af | ||
|
|
a5503308c5 | ||
|
|
e61afdefa3 | ||
|
|
426d70a790 | ||
|
|
b03a212fbf | ||
|
|
1833e7c921 | ||
|
|
777ec63a71 | ||
|
|
0a6e5ae9c1 | ||
|
|
ee448a37e9 | ||
|
|
9c051052b0 | ||
|
|
4d7c487614 | ||
|
|
65025cc448 | ||
|
|
bbba1d9bb7 | ||
|
|
99dc96c644 | ||
|
|
2a27d2030a | ||
|
|
cd160caaa1 | ||
|
|
d27b5eb23e | ||
|
|
f9d704a900 | ||
|
|
2f6e00f512 | ||
|
|
5aa312e437 | ||
|
|
ebaf36a8be | ||
|
|
babe93b99a | ||
|
|
a4e9f3cab7 | ||
|
|
b06866877a | ||
|
|
967cdfebc8 | ||
|
|
3c11c60126 | ||
|
|
2963e8a757 | ||
|
|
cb2d4ea88a | ||
|
|
add7ea07ee | ||
|
|
da8726b2cb | ||
|
|
3358877054 | ||
|
|
1f7798c7c1 | ||
|
|
c7b3bb5e58 | ||
|
|
f661f21675 | ||
|
|
b6164aa59b | ||
|
|
4209d7f7c0 | ||
|
|
334b338ab0 | ||
|
|
72f33be6f2 | ||
|
|
84890b8e61 | ||
|
|
c6668adcf3 | ||
|
|
a178ed5c22 | ||
|
|
7601c74c9c | ||
|
|
fad9ee4d21 | ||
|
|
d1a9913c47 | ||
|
|
e4ca2623cb | ||
|
|
9c1bf37960 | ||
|
|
f46528471b | ||
|
|
191680940b | ||
|
|
ee02afec56 | ||
|
|
a458028de2 | ||
|
|
abd8f2c269 | ||
|
|
f3ad4e39e4 | ||
|
|
e0a5cbf0e7 | ||
|
|
953697cd86 | ||
|
|
3bd2122eb4 | ||
|
|
50b0527858 | ||
|
|
b044fcdec2 | ||
|
|
b0508fcf2c | ||
|
|
ce89b0aebc | ||
|
|
d5008ed828 | ||
|
|
d467716e26 | ||
|
|
199e21b3ef | ||
|
|
1d926f2e67 | ||
|
|
4a71a391b8 | ||
|
|
d3ed4e46e2 | ||
|
|
057a1026d7 | ||
|
|
1ba171a58d | ||
|
|
1adac67155 | ||
|
|
42be1a3773 | ||
|
|
0a49fafa0d | ||
|
|
4a5d5e1f3b | ||
|
|
583a2ec2e4 | ||
|
|
19765e89e9 | ||
|
|
9895bc83bf | ||
|
|
ab98c31f16 | ||
|
|
f9c9c4188a | ||
|
|
719e8b1a20 | ||
|
|
f1b47178d8 | ||
|
|
59db08e961 | ||
|
|
6fc20b9562 | ||
|
|
fac8659161 | ||
|
|
4d9332ce7d | ||
|
|
62444ce746 | ||
|
|
2431a6bf91 | ||
|
|
d1263e7228 | ||
|
|
30ddd522a4 | ||
|
|
635bace09e | ||
|
|
f1113e3eb0 | ||
|
|
cc5f819ce7 | ||
|
|
82cd24bb75 | ||
|
|
d45c397c6a | ||
|
|
45bf3f57d7 | ||
|
|
1d88ba9d69 | ||
|
|
c0965c6c31 | ||
|
|
34ddd2ac02 | ||
|
|
345d781e97 | ||
|
|
28cf831701 | ||
|
|
60c62f8f84 | ||
|
|
7faa21f95f | ||
|
|
4e9f951551 | ||
|
|
870141298c | ||
|
|
872faa422a | ||
|
|
fc9cb66813 | ||
|
|
a175d1a327 | ||
|
|
6206fff118 | ||
|
|
b5067249c0 | ||
|
|
f4f9831d39 | ||
|
|
254faaf64c | ||
|
|
8e7aea4fcf | ||
|
|
270faf2069 | ||
|
|
b7c1cc77cc | ||
|
|
9a45ec221c | ||
|
|
3e13ee6fc3 | ||
|
|
b7d20a0ff0 | ||
|
|
c1bb9c2bde | ||
|
|
11e9def0b2 | ||
|
|
3104f40f6e | ||
|
|
e9b4ceeee5 | ||
|
|
437641fb43 | ||
|
|
bfd60b3921 | ||
|
|
1e67bf97f0 | ||
|
|
c21d2302e7 | ||
|
|
4ed62e181d | ||
|
|
52a755a08c | ||
|
|
9a8d3cbd90 | ||
|
|
b101ce06bd | ||
|
|
c83fd179a8 | ||
|
|
5258305745 | ||
|
|
ce781831ee | ||
|
|
58297daf6d | ||
|
|
3393a08f7e | ||
|
|
5b2ddeccdb | ||
|
|
26cc1072dd | ||
|
|
12973711f6 | ||
|
|
909ac9dd41 | ||
|
|
d94a07d417 | ||
|
|
b32dd8bfc4 | ||
|
|
9feb0e597b | ||
|
|
9dab84a573 | ||
|
|
d089c7fce0 | ||
|
|
253a080df5 | ||
|
|
0c6e4b2aee | ||
|
|
e14bbde77d | ||
|
|
7496163467 | ||
|
|
696a94d1ce | ||
|
|
2699b0974c | ||
|
|
90c0250ba4 | ||
|
|
eb96153ffd | ||
|
|
47e3eb9b5b | ||
|
|
b8b07adeef | ||
|
|
d0e9e37ef6 | ||
|
|
820f92d8cb | ||
|
|
e42523af84 | ||
|
|
e2184d5e06 | ||
|
|
7fe0353260 | ||
|
|
0f2eba507e | ||
|
|
55e08474f3 | ||
|
|
28bdc52e1d | ||
|
|
e4221fa6c3 | ||
|
|
1652db9a2d | ||
|
|
601f17653a | ||
|
|
7718190fcd | ||
|
|
349c7dcb9e | ||
|
|
1c42b867cf | ||
|
|
d4771e563e | ||
|
|
b0a5fc0693 | ||
|
|
3b96fb8776 | ||
|
|
7f93c4b978 | ||
|
|
15c3df1cba | ||
|
|
7fb8e66c01 | ||
|
|
728e1f1290 | ||
|
|
87b9ed6ecd | ||
|
|
38b4ebe8ba | ||
|
|
d098af3185 | ||
|
|
4e56130a40 | ||
|
|
2bbdc70187 | ||
|
|
b678a55f63 | ||
|
|
5491964e81 | ||
|
|
b05297a96d | ||
|
|
197293e25e | ||
|
|
ba41c4ab56 | ||
|
|
bda72b8bc0 | ||
|
|
bb6b9f4cb1 | ||
|
|
e40b5a3ea0 | ||
|
|
4cfed6e98e | ||
|
|
687e3dd5e2 | ||
|
|
e4140cd299 | ||
|
|
8e056cbdf2 | ||
|
|
9dcfb38967 | ||
|
|
47b9235d70 | ||
|
|
f3cd53a4db | ||
|
|
dbdb4ea66c | ||
|
|
00424d7ca3 | ||
|
|
4b738d6f63 | ||
|
|
8a5e2adb1e | ||
|
|
f85329e112 | ||
|
|
46efbdf1d9 | ||
|
|
8885ade003 | ||
|
|
2564928d83 | ||
|
|
56114d3071 | ||
|
|
5b9977c9af | ||
|
|
12a544164f | ||
|
|
2ca1156b7e | ||
|
|
3ad3683ca7 | ||
|
|
1599bd87a0 | ||
|
|
90623400a4 | ||
|
|
64e44fb24f | ||
|
|
156b9a133f |
13
.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
||||
.git
|
||||
.github
|
||||
.venv
|
||||
__pycache__
|
||||
*.pyc
|
||||
.pytest_cache
|
||||
.mypy_cache
|
||||
.ruff_cache
|
||||
.cache
|
||||
.tmp
|
||||
.secrets
|
||||
dist
|
||||
build
|
||||
61
.github/workflows/publish-docker.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
name: Publish Docker Images
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Image tag to publish (without image suffix)"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
IMAGE_TAG: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.tag || github.ref_name }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- image_suffix: cpu-diarization-sortformer
|
||||
dockerfile: Dockerfile.cpu
|
||||
extras: cpu,diarization-sortformer
|
||||
- image_suffix: cu129-diarization-sortformer
|
||||
dockerfile: Dockerfile
|
||||
extras: cu129,diarization-sortformer
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set lowercase owner
|
||||
id: owner
|
||||
run: echo "value=${GITHUB_REPOSITORY_OWNER,,}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Setup Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build and push image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./${{ matrix.dockerfile }}
|
||||
push: true
|
||||
build-args: |
|
||||
EXTRAS=${{ matrix.extras }}
|
||||
tags: |
|
||||
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:${{ env.IMAGE_TAG }}-${{ matrix.image_suffix }}
|
||||
ghcr.io/${{ steps.owner.outputs.value }}/whisperlivekit:latest-${{ matrix.image_suffix }}
|
||||
23
.gitignore
vendored
@@ -54,21 +54,6 @@ coverage.xml
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
@@ -134,7 +119,11 @@ run_*.sh
|
||||
*.pt
|
||||
|
||||
# Debug & testing
|
||||
test_*.py
|
||||
/test_*.py
|
||||
!test_backend_offline.py
|
||||
launch.json
|
||||
.DS_Store
|
||||
test/*
|
||||
/test/
|
||||
!tests/
|
||||
nllb-200-distilled-600M-ctranslate2/*
|
||||
*.mp3
|
||||
205
BENCHMARK.md
Normal file
@@ -0,0 +1,205 @@
|
||||
# WhisperLiveKit Benchmark Report
|
||||
|
||||
Benchmark comparing all supported ASR backends, streaming policies, and model sizes on Apple Silicon.
|
||||
All tests run through the full AudioProcessor pipeline (same code path as production WebSocket).
|
||||
|
||||
## Test Environment
|
||||
|
||||
| Property | Value |
|
||||
|----------|-------|
|
||||
| Hardware | Apple M4, 32 GB RAM |
|
||||
| OS | macOS 25.3.0 (arm64) |
|
||||
| Python | 3.13 |
|
||||
| faster-whisper | 1.2.1 |
|
||||
| mlx-whisper | installed (via mlx) |
|
||||
| Voxtral MLX | native MLX backend |
|
||||
| Voxtral (HF) | transformers-based |
|
||||
| VAC (Silero VAD) | enabled unless noted |
|
||||
| Chunk size | 100 ms |
|
||||
| Pacing | no-realtime (as fast as possible) |
|
||||
|
||||
## Audio Test Files
|
||||
|
||||
| File | Duration | Language | Speakers | Description |
|
||||
|------|----------|----------|----------|-------------|
|
||||
| `00_00_07_english_1_speaker.wav` | 7.2 s | English | 1 | Short dictation with pauses |
|
||||
| `00_00_16_french_1_speaker.wav` | 16.3 s | French | 1 | French speech with intentional silence gaps |
|
||||
| `00_00_30_english_3_speakers.wav` | 30.0 s | English | 3 | Multi-speaker conversation |
|
||||
|
||||
Ground truth transcripts (`.transcript.json`) with per-word timestamps are hand-verified.
|
||||
|
||||
---
|
||||
|
||||
## Results
|
||||
|
||||
### English -- Short (7.2 s, 1 speaker)
|
||||
|
||||
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|
||||
|---------|--------|-------|-----|-----|---------------|
|
||||
| faster-whisper | LocalAgreement | base | 0.20x | 21.1% | 0.080 s |
|
||||
| faster-whisper | SimulStreaming | base | 0.14x | 0.0% | 0.239 s |
|
||||
| faster-whisper | LocalAgreement | small | 0.59x | 21.1% | 0.089 s |
|
||||
| faster-whisper | SimulStreaming | small | 0.39x | 0.0% | 0.221 s |
|
||||
| mlx-whisper | LocalAgreement | base | 0.05x | 21.1% | 0.080 s |
|
||||
| mlx-whisper | SimulStreaming | base | 0.14x | 10.5% | 0.245 s |
|
||||
| mlx-whisper | LocalAgreement | small | 0.16x | 21.1% | 0.089 s |
|
||||
| mlx-whisper | SimulStreaming | small | 0.20x | 10.5% | 0.226 s |
|
||||
| voxtral-mlx | voxtral | 4B | 0.32x | 0.0% | 0.254 s |
|
||||
| voxtral (HF) | voxtral | 4B | 1.29x | 0.0% | 1.876 s |
|
||||
|
||||
### English -- Multi-speaker (30.0 s, 3 speakers)
|
||||
|
||||
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|
||||
|---------|--------|-------|-----|-----|---------------|
|
||||
| faster-whisper | LocalAgreement | base | 0.24x | 44.7% | 0.235 s |
|
||||
| faster-whisper | SimulStreaming | base | 0.10x | 5.3% | 0.398 s |
|
||||
| faster-whisper | LocalAgreement | small | 0.59x | 25.0% | 0.226 s |
|
||||
| faster-whisper | SimulStreaming | small | 0.26x | 5.3% | 0.387 s |
|
||||
| mlx-whisper | LocalAgreement | base | 0.06x | 23.7% | 0.237 s |
|
||||
| mlx-whisper | SimulStreaming | base | 0.11x | 5.3% | 0.395 s |
|
||||
| mlx-whisper | LocalAgreement | small | 0.13x | 25.0% | 0.226 s |
|
||||
| mlx-whisper | SimulStreaming | small | 0.20x | 5.3% | 0.394 s |
|
||||
| voxtral-mlx | voxtral | 4B | 0.31x | 9.2% | 0.176 s |
|
||||
| voxtral (HF) | voxtral | 4B | 1.00x | 32.9% | 1.034 s |
|
||||
|
||||
<p align="center">
|
||||
<img src="benchmark_chart.png" alt="Benchmark comparison on 30s English" width="800">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
|
||||
</p>
|
||||
|
||||
### French (16.3 s, 1 speaker, `--language fr`)
|
||||
|
||||
| Backend | Policy | Model | RTF | WER | Timestamp MAE |
|
||||
|---------|--------|-------|-----|-----|---------------|
|
||||
| faster-whisper | LocalAgreement | base | 0.22x | 25.7% | 3.460 s |
|
||||
| faster-whisper | SimulStreaming | base | 0.10x | 31.4% | 3.660 s |
|
||||
| faster-whisper | LocalAgreement | small | 0.76x | 42.9% | 0.051 s |
|
||||
| faster-whisper | SimulStreaming | small | 0.29x | 25.7% | 0.219 s |
|
||||
| mlx-whisper | LocalAgreement | base | 0.09x | ~45%* | ~5.0 s* |
|
||||
| mlx-whisper | SimulStreaming | base | 0.09x | 40.0% | 3.540 s |
|
||||
| mlx-whisper | LocalAgreement | small | 0.14x | 25.7% | 0.083 s |
|
||||
| mlx-whisper | SimulStreaming | small | 0.17x | 31.4% | 0.203 s |
|
||||
| voxtral-mlx | voxtral | 4B | 0.18x | 37.1% | 3.422 s |
|
||||
| voxtral (HF) | voxtral | 4B | 0.63x | 28.6% | 4.040 s |
|
||||
|
||||
\* mlx-whisper + LocalAgreement + base is unstable on this French file (WER fluctuates 34-1037% across runs due to hallucination loops). The `small` model does not have this problem.
|
||||
|
||||
**Timestamp note:** The base model produces very high timestamp MAE (3.4-3.7s) on this French file because it misaligns words around the silence gaps. The small model handles this much better (0.05-0.22s MAE). Voxtral also drifts on the silence gaps.
|
||||
|
||||
---
|
||||
|
||||
## Model Size Comparison (base vs small)
|
||||
|
||||
| | base | small | Observation |
|
||||
|--|------|-------|-------------|
|
||||
| **RTF** | 0.05-0.24x | 0.13-0.76x | small is 2-3x slower |
|
||||
| **English WER (SS)** | 0-5.3% | 0-5.3% | No improvement: SimulStreaming already saturates on base |
|
||||
| **English WER (LA)** | 21-44.7% | 21-25% | small reduces LA errors on longer audio |
|
||||
| **French WER** | 25-40% | 25-43% | Mixed: depends on backend/policy combo |
|
||||
| **French timestamps** | 3.4-5.0s MAE | 0.05-0.22s MAE | small is dramatically better for French timestamps |
|
||||
|
||||
In short: **base + SimulStreaming** gives the best speed/accuracy tradeoff for English. The small model only helps if you need LocalAgreement (for subtitle-grade timestamps) or non-English languages.
|
||||
|
||||
---
|
||||
|
||||
## Key Findings
|
||||
|
||||
### Speed (RTF = processing time / audio duration, lower is better)
|
||||
|
||||
1. **mlx-whisper + LocalAgreement + base** is the fastest combo on Apple Silicon: 0.05-0.06x RTF on English. 30 seconds of audio in under 2 seconds.
|
||||
2. For **faster-whisper**, SimulStreaming is faster than LocalAgreement. For **mlx-whisper**, it is the opposite: LocalAgreement (0.05-0.06x) outperforms SimulStreaming (0.11-0.14x) on speed.
|
||||
3. **voxtral-mlx** runs at 0.18-0.32x RTF -- 3-5x slower than mlx-whisper base, but well within real-time.
|
||||
4. **voxtral (HF transformers)** hits 1.0-1.3x RTF. At the real-time boundary on Apple Silicon. Use the MLX variant instead.
|
||||
5. The **small** model is 2-3x slower than base across all backends.
|
||||
|
||||
### Accuracy (WER = Word Error Rate, lower is better)
|
||||
|
||||
1. **SimulStreaming** gives dramatically lower WER than LocalAgreement on the whisper backends. On the 30s English file: 5.3% vs 23-44%.
|
||||
2. **voxtral-mlx** hits 0% on short English and 9.2% on multi-speaker. It auto-detects language natively. Whisper also supports `--language auto`, but tends to bias towards English on short segments.
|
||||
3. **LocalAgreement** tends to repeat the last sentence at end-of-stream (a known LCP artifact), inflating WER. This is visible in the 21% WER on the 7s file -- the same 4 extra words appear in every LA run.
|
||||
4. On **French** with the correct `--language fr`, whisper base achieves 25-40% WER -- comparable to Voxtral's 28-37%. The small model does not consistently improve French WER.
|
||||
|
||||
### Timestamps (MAE = Mean Absolute Error on word start times)
|
||||
|
||||
1. **LocalAgreement** gives the best timestamps on English (0.08-0.09s MAE).
|
||||
2. **SimulStreaming** is less precise (0.22-0.40s MAE) but good enough for most applications.
|
||||
3. On French with silence gaps, **base model timestamps are unreliable** (3.4-5s MAE). The **small model fixes this** (0.05-0.22s MAE). This is the strongest argument for using `small` over `base`.
|
||||
4. **voxtral-mlx** has good timestamps on English (0.18-0.25s MAE) but drifts on audio with long silence gaps (3.4s MAE on the French file).
|
||||
|
||||
### VAC (Voice Activity Classification) Impact
|
||||
|
||||
| Backend | Policy | VAC | 7s English WER | 30s English WER |
|
||||
|---------|--------|-----|----------------|-----------------|
|
||||
| faster-whisper | LocalAgreement | on | 21.1% | 44.7% |
|
||||
| faster-whisper | LocalAgreement | off | 100.0% | 100.0% |
|
||||
| voxtral-mlx | voxtral | on | 0.0% | 9.2% |
|
||||
| voxtral-mlx | voxtral | off | 0.0% | 9.2% |
|
||||
|
||||
- **Whisper backends need VAC** to work in streaming mode. Without it the buffer logic breaks down and you get empty or garbage output.
|
||||
- **Voxtral is unaffected by VAC** since it handles its own internal chunking. Identical results with or without. VAC still saves compute on silent segments.
|
||||
|
||||
---
|
||||
|
||||
## Recommendations
|
||||
|
||||
| Use Case | Backend | Policy | Model | Notes |
|
||||
|----------|---------|--------|-------|-------|
|
||||
| Fastest English (Apple Silicon) | mlx-whisper | SimulStreaming | base | 0.11x RTF, 5.3% WER |
|
||||
| Fastest English (Linux/GPU) | faster-whisper | SimulStreaming | base | 0.10x RTF, 5.3% WER |
|
||||
| Best accuracy, English | faster-whisper | SimulStreaming | small | 0.26x RTF, 5.3% WER, still fast |
|
||||
| Multilingual / auto-detect | voxtral-mlx | voxtral | 4B | 100+ languages, 0.18-0.32x RTF |
|
||||
| Best timestamps | any | LocalAgreement | small | 0.05-0.09s MAE, good for subtitles |
|
||||
| Low memory / embedded | mlx-whisper | SimulStreaming | base | Smallest footprint, fastest response |
|
||||
|
||||
---
|
||||
|
||||
## Caveats
|
||||
|
||||
- **3 test files, ~53 seconds total.** Results give relative rankings between backends but should not be taken as definitive WER numbers. Run on your own data for production decisions.
|
||||
- **RTF varies between runs** (up to +/-30%) depending on thermal state, background processes, and model caching. The numbers above are single sequential runs on a warm machine.
|
||||
- **Only base and small tested.** Medium and large-v3 would likely improve WER at the cost of higher RTF. We did not test them here because they are slow on Apple Silicon without GPU.
|
||||
|
||||
---
|
||||
|
||||
## Reproducing These Benchmarks
|
||||
|
||||
```bash
|
||||
# Install test dependencies
|
||||
pip install -e ".[test]"
|
||||
|
||||
# Single backend test
|
||||
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --model base --no-realtime
|
||||
|
||||
# With a specific language
|
||||
python test_backend_offline.py --backend mlx-whisper --policy simulstreaming --model small --lan fr --no-realtime
|
||||
|
||||
# Multi-backend auto-detect benchmark
|
||||
python test_backend_offline.py --benchmark --no-realtime
|
||||
|
||||
# Export to JSON
|
||||
python test_backend_offline.py --benchmark --no-realtime --json results.json
|
||||
|
||||
# Test with your own audio
|
||||
python test_backend_offline.py --backend voxtral-mlx --audio your_file.wav --no-realtime
|
||||
```
|
||||
|
||||
The benchmark harness computes WER and timestamp accuracy automatically when ground truth
|
||||
`.transcript.json` files exist alongside the audio files. See `audio_tests/` for the format.
|
||||
|
||||
---
|
||||
|
||||
## Help Us Benchmark on More Hardware
|
||||
|
||||
These results are from a single Apple M4 machine. We'd love to see numbers from other setups: Linux with CUDA GPUs, older Macs, different CPU architectures, cloud instances, etc.
|
||||
|
||||
If you run the benchmark on your hardware, please open an issue or PR with your results and we will add them here. The more data points we have, the better the recommendations get.
|
||||
|
||||
What we are especially interested in:
|
||||
- **NVIDIA GPUs** (RTX 3090, 4090, A100, T4, etc.) with faster-whisper
|
||||
- **Older Apple Silicon** (M1, M2, M3) with mlx-whisper and voxtral-mlx
|
||||
- **Medium and large-v3 models** (we only tested base and small so far)
|
||||
- **Longer audio files** or domain-specific audio (medical, legal, call center)
|
||||
- **Other languages** beyond English and French
|
||||
@@ -15,7 +15,7 @@ Thank you for considering contributing ! We appreciate your time and effort to h
|
||||
|
||||
## Opening Issues
|
||||
|
||||
If you encounter a problem with diart or want to suggest an improvement, please follow these guidelines when opening an issue:
|
||||
If you encounter a problem with WhisperLiveKit or want to suggest an improvement, please follow these guidelines when opening an issue:
|
||||
|
||||
- **Bug Reports:**
|
||||
- Clearly describe the error. **Please indicate the parameters you use, especially the model(s)**
|
||||
@@ -43,4 +43,4 @@ We welcome and appreciate contributions! To ensure a smooth review process, plea
|
||||
|
||||
## Thank You
|
||||
|
||||
Your contributions make diart better for everyone. Thank you for your time and dedication!
|
||||
Your contributions make WhisperLiveKit better for everyone. Thank you for your time and dedication!
|
||||
|
||||
91
DEV_NOTES.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# 1. Simulstreaming: Decouple the encoder for faster inference
|
||||
|
||||
Simulstreaming encoder time (whisperlivekit/simul_whisper/simul_whisper.py l. 397) experimentations :
|
||||
|
||||
On macOS Apple Silicon M4 :
|
||||
|
||||
| Encoder | base.en | small |
|
||||
|--------|---------|-------|
|
||||
| WHISPER (no modification) | 0.35s | 1.09s |
|
||||
| FASTER_WHISPER | 0.4s | 1.20s |
|
||||
| MLX_WHISPER | 0.07s | 0.20s |
|
||||
|
||||
Memory saved by only loading encoder for optimized framework:
|
||||
|
||||
For tiny.en, mlx whisper:
|
||||
Sizes MLX whisper:
|
||||
Decoder weights: 59110771 bytes
|
||||
Encoder weights: 15268874 bytes
|
||||
|
||||
|
||||
# 2. Translation: Faster model for each system
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
Testing on MacBook M3 with NLLB-200-distilled-600M model:
|
||||
|
||||
### Standard Transformers vs CTranslate2
|
||||
|
||||
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|
||||
|-----------|-------------------------|---------------------------|---------|
|
||||
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
|
||||
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
|
||||
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
|
||||
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
|
||||
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
|
||||
|
||||
**Results:**
|
||||
- Total Standard time: 4.1068s
|
||||
- Total CTranslate2 time: 8.5476s
|
||||
- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation.
|
||||
|
||||
|
||||
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||
|
||||
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
||||
|
||||
## Problem Statement
|
||||
- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
|
||||
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers
|
||||
|
||||
#
|
||||
### Initial Setup
|
||||
For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.
|
||||
|
||||
Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
|
||||
|
||||
### Algorithm
|
||||
|
||||
```python
|
||||
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
|
||||
```
|
||||
|
||||
- `DS_a_{i}`: Top detected speaker for prediction i
|
||||
- `DS_b_{i}`: Second detected speaker for prediction i
|
||||
- `AS_{i}`: Attributed speaker for prediction i
|
||||
- `GTS_A`: Ground truth speaker A
|
||||
- `GTS_B`: Ground truth speaker B
|
||||
- `DIST(a, b)`: Distance between detected speakers a and b
|
||||
|
||||
3. **Attribution Logic**
|
||||
|
||||
```
|
||||
AS_0 ← A
|
||||
|
||||
AS_1 ← B
|
||||
|
||||
IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
|
||||
DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
|
||||
# Likely that DS_a_0 = DS_a_1 (same speaker)
|
||||
AS_1 ← A
|
||||
AS_2 ← B
|
||||
|
||||
ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
|
||||
DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
|
||||
AS_2 ← A
|
||||
|
||||
ELSE:
|
||||
AS_2 ← B
|
||||
|
||||
to finish
|
||||
```
|
||||
119
Dockerfile
@@ -1,82 +1,75 @@
|
||||
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
|
||||
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
|
||||
|
||||
# --- MARK: Builder Stage
|
||||
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 AS builder-gpu
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ARG EXTRAS
|
||||
ARG HF_PRECACHE_DIR
|
||||
ARG HF_TKN_FILE
|
||||
|
||||
# Install system dependencies
|
||||
#RUN apt-get update && \
|
||||
# apt-get install -y ffmpeg git && \
|
||||
# apt-get clean && \
|
||||
# rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 2) Install system dependencies + Python + pip
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
python3 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
git && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
python3-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
# Install UV and set up the environment
|
||||
COPY --from=uvbin /uv /uvx /bin/
|
||||
|
||||
COPY . .
|
||||
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
|
||||
ENV UV_PYTHON_PREFERENCE=only-managed
|
||||
ENV UV_PYTHON_INSTALL_DIR=/python
|
||||
|
||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
||||
# Note: For gates models, need to add your HF toke. See README.md
|
||||
# for more details.
|
||||
RUN if [ -n "$EXTRAS" ]; then \
|
||||
echo "Installing with extras: [$EXTRAS]"; \
|
||||
pip install --no-cache-dir .[$EXTRAS]; \
|
||||
else \
|
||||
echo "Installing base package only"; \
|
||||
pip install --no-cache-dir .; \
|
||||
fi
|
||||
RUN uv python install 3.12
|
||||
|
||||
# Enable in-container caching for Hugging Face models by:
|
||||
# Note: If running multiple containers, better to map a shared
|
||||
# bucket.
|
||||
#
|
||||
# A) Make the cache directory persistent via an anonymous volume.
|
||||
# Note: This only persists for a single, named container. This is
|
||||
# only for convenience at de/test stage.
|
||||
# For prod, it is better to use a named volume via host mount/k8s.
|
||||
VOLUME ["/root/.cache/huggingface/hub"]
|
||||
# Install dependencies first to leverage caching
|
||||
ARG EXTRAS=cu129
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
RUN set -eux; \
|
||||
set --; \
|
||||
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||
set -- "$@" --extra "$extra"; \
|
||||
done; \
|
||||
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
|
||||
|
||||
# or
|
||||
# B) Conditionally copy a local pre-cache from the build context to the
|
||||
# container's cache via the HF_PRECACHE_DIR build-arg.
|
||||
# WARNING: This will copy ALL files in the pre-cache location.
|
||||
# Copy the source code and install the package only
|
||||
COPY whisperlivekit /app/whisperlivekit
|
||||
RUN set -eux; \
|
||||
set --; \
|
||||
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||
set -- "$@" --extra "$extra"; \
|
||||
done; \
|
||||
uv sync --frozen --no-editable --no-cache "$@"
|
||||
|
||||
# Conditionally copy a cache directory if provided
|
||||
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
|
||||
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
|
||||
mkdir -p /root/.cache/huggingface/hub && \
|
||||
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
|
||||
else \
|
||||
echo "No local Hugging Face cache specified, skipping copy"; \
|
||||
fi
|
||||
# --- MARK: Runtime Stage
|
||||
FROM nvidia/cuda:12.9.1-cudnn-runtime-ubuntu24.04
|
||||
|
||||
# Conditionally copy a Hugging Face token if provided
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ffmpeg &&\
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy UV binaries
|
||||
COPY --from=uvbin /uv /uvx /bin/
|
||||
|
||||
# Copy the Python version
|
||||
COPY --from=builder-gpu --chown=python:python /python /python
|
||||
|
||||
# Copy the virtual environment with all dependencies installed
|
||||
COPY --from=builder-gpu /app/.venv /app/.venv
|
||||
|
||||
RUN if [ -n "$HF_TKN_FILE" ]; then \
|
||||
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
|
||||
mkdir -p /root/.cache/huggingface && \
|
||||
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
|
||||
else \
|
||||
echo "No Hugging Face token file specified, skipping token setup"; \
|
||||
fi
|
||||
|
||||
# Expose port for the transcription server
|
||||
EXPOSE 8000
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_PYTHON_DOWNLOADS=0
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
||||
|
||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||
|
||||
# Default args
|
||||
CMD ["--model", "tiny.en"]
|
||||
CMD ["--model", "medium"]
|
||||
|
||||
76
Dockerfile.cpu
Normal file
@@ -0,0 +1,76 @@
|
||||
FROM ghcr.io/astral-sh/uv:0.10.4 AS uvbin
|
||||
|
||||
# --- MARK: Builder Stage
|
||||
FROM debian:bookworm-slim AS builder-cpu
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
python3-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install UV and set up the environment
|
||||
COPY --from=uvbin /uv /uvx /bin/
|
||||
|
||||
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_NO_DEV=1
|
||||
ENV UV_PYTHON_PREFERENCE=only-managed
|
||||
ENV UV_PYTHON_INSTALL_DIR=/python
|
||||
|
||||
RUN uv python install 3.12
|
||||
|
||||
# Install dependencies first to leverage caching
|
||||
ARG EXTRAS=cpu
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
RUN set -eux; \
|
||||
set --; \
|
||||
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||
set -- "$@" --extra "$extra"; \
|
||||
done; \
|
||||
uv sync --frozen --no-install-project --no-editable --no-cache "$@"
|
||||
|
||||
# Copy the source code and install the package only
|
||||
COPY whisperlivekit /app/whisperlivekit
|
||||
RUN set -eux; \
|
||||
set --; \
|
||||
for extra in $(echo "${EXTRAS:-}" | tr ',' ' '); do \
|
||||
set -- "$@" --extra "$extra"; \
|
||||
done; \
|
||||
uv sync --frozen --no-editable --no-cache "$@"
|
||||
|
||||
# --- MARK: Runtime Stage
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ffmpeg &&\
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy UV binaries
|
||||
COPY --from=uvbin /uv /uvx /bin/
|
||||
|
||||
# Copy the Python version
|
||||
COPY --from=builder-cpu --chown=python:python /python /python
|
||||
|
||||
# Copy the virtual environment with all dependencies installed
|
||||
COPY --from=builder-cpu /app/.venv /app/.venv
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_PYTHON_DOWNLOADS=0
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/')" || exit 1
|
||||
|
||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||
|
||||
# Default args - you might want to use a smaller model for CPU
|
||||
CMD ["--model", "tiny"]
|
||||
226
LICENSE
@@ -1,52 +1,210 @@
|
||||
# License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
## Main Software License
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
MIT License
|
||||
1. Definitions.
|
||||
|
||||
Copyright (c) 2025 Quentin Fuxa.
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
## SimulStreaming Backend License
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
**When using the SimulStreaming backend (SimulWhisper), additional licensing terms apply:**
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
SimulStreaming (https://github.com/ufal/SimulStreaming) is dual-licensed:
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
### 🔹 Non-Commercial Use
|
||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you obtain the code through the GitHub repository. This license is **free of charge** and comes with **no obligations** for non-commercial users.
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
### 🔸 Commercial Use
|
||||
Understanding who uses SimulStreaming commercially helps improve and prioritize development. Therefore, **registration is required** for those who acquire a commercial license.
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
Commercial licenses are planned to be **affordable** to SMEs and individuals. They are considering providing commercial licenses either for free or for a symbolic one-time fee, and may also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft.com/e/7tCxb4gJfB).
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
You can also leave your contact [there](https://forms.cloud.microsoft.com/e/7tCxb4gJfB) to be notified when commercial licenses become available.
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
**Contact for SimulStreaming licensing:**
|
||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 Quentin Fuxa
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
---
|
||||
|
||||
## Based on:
|
||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
|
||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
|
||||
- **SimulStreaming** by ÚFAL – Dual License (PolyForm Noncommercial License 1.0.0 / Commercial License) – https://github.com/ufal/SimulStreaming
|
||||
- **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University – Apache-2.0 – https://github.com/ufal/SimulStreaming
|
||||
- **SimulStreaming** by ÚFAL – MIT License – https://github.com/ufal/SimulStreaming
|
||||
- **NeMo** by NVidia - Apache-2.0 - https://github.com/NVIDIA-NeMo/NeMo
|
||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming.
|
||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad.
|
||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart.
|
||||
|
||||
409
README.md
@@ -1,143 +1,159 @@
|
||||
<h1 align="center">WhisperLiveKit</h1>
|
||||
<h1 align="center">WLK</h1>
|
||||
<p align="center"><b>WhisperLiveKit: Ultra-low-latency, self-hosted speech-to-text with speaker identification</b></p>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
||||
</p>
|
||||
|
||||
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Diarization</b></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
|
||||
<a href="https://huggingface.co/qfuxa/whisper-base-french-lora">
|
||||
<img alt="Hugging Face Weights" src="https://img.shields.io/badge/🤗-Hugging%20Face%20Weights-yellow" />
|
||||
</a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
|
||||
</p>
|
||||
|
||||
## Overview
|
||||
|
||||
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine ✨
|
||||
### Powered by Leading Research:
|
||||
|
||||
**See the interactive playground in [this repo](https://github.com/QuentinFuxa/streamlit-d3-network) to explore how AlignAtt works**
|
||||
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408).
|
||||
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||
- [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) (2025) - 4B-parameter multilingual speech model by Mistral AI
|
||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
||||
|
||||
|
||||
> **Why not just run a simple Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts off words mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.
|
||||
|
||||
|
||||
### Architecture
|
||||
|
||||
WhisperLiveKit consists of three main components:
|
||||
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
|
||||
|
||||
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the [provided template](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
|
||||
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
|
||||
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
|
||||
*The backend supports multiple concurrent users. Voice Activity Detection reduces overhead when no voice is detected.*
|
||||
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
|
||||
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
|
||||
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
||||
- **Automatic Silence Chunking** – Automatically chunks when no audio is detected to limit buffer size
|
||||
- **Confidence Validation** – Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
|
||||
- **Buffering Preview** – Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
|
||||
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
|
||||
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy.
|
||||
|
||||
## Quick Start
|
||||
### Installation & Quick Start
|
||||
|
||||
```bash
|
||||
# Install the package
|
||||
pip install whisperlivekit
|
||||
|
||||
# Start the transcription server
|
||||
whisperlivekit-server --model tiny.en
|
||||
|
||||
# Open your browser at http://localhost:8000 to see the interface.
|
||||
# Use -ssl-certfile public.crt --ssl-keyfile private.key parameters to use SSL
|
||||
```
|
||||
> You can also clone the repo and `pip install -e .` for the latest version.
|
||||
|
||||
That's it! Start speaking and watch your words appear on screen.
|
||||
#### Quick Start
|
||||
1. **Start the transcription server:**
|
||||
```bash
|
||||
wlk --model base --language en
|
||||
```
|
||||
|
||||
## Installation
|
||||
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||
|
||||
|
||||
> - See [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||
> - Check the [troubleshooting guide](docs/troubleshooting.md) for step-by-step fixes collected from recent GPU setup/env issues.
|
||||
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||
|
||||
|
||||
#### Use it to capture audio from web pages.
|
||||
|
||||
Go to `chrome-extension` for instructions.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
#### Optional Dependencies
|
||||
|
||||
| Feature | `uv sync` | `pip install -e` |
|
||||
|-----------|-------------|-------------|
|
||||
| **Apple Silicon MLX Whisper backend** | `uv sync --extra mlx-whisper` | `pip install -e ".[mlx-whisper]"` |
|
||||
| **Voxtral (MLX backend, Apple Silicon)** | `uv sync --extra voxtral-mlx` | `pip install -e ".[voxtral-mlx]"` |
|
||||
| **CPU PyTorch stack** | `uv sync --extra cpu` | `pip install -e ".[cpu]"` |
|
||||
| **CUDA 12.9 PyTorch stack** | `uv sync --extra cu129` | `pip install -e ".[cu129]"` |
|
||||
| **Translation** | `uv sync --extra translation` | `pip install -e ".[translation]"` |
|
||||
| **Sentence tokenizer** | `uv sync --extra sentence_tokenizer` | `pip install -e ".[sentence_tokenizer]"` |
|
||||
| **Voxtral (HF backend)** | `uv sync --extra voxtral-hf` | `pip install -e ".[voxtral-hf]"` |
|
||||
| **Speaker diarization (Sortformer / NeMo)** | `uv sync --extra diarization-sortformer` | `pip install -e ".[diarization-sortformer]"` |
|
||||
| *[Not recommended]* Speaker diarization with Diart | `uv sync --extra diarization-diart` | `pip install -e ".[diarization-diart]"` |
|
||||
|
||||
Supported GPU profiles:
|
||||
|
||||
```bash
|
||||
#Install from PyPI (Recommended)
|
||||
pip install whisperlivekit
|
||||
# Profile A: Sortformer diarization
|
||||
uv sync --extra cu129 --extra diarization-sortformer
|
||||
|
||||
#Install from Source
|
||||
git clone https://github.com/QuentinFuxa/WhisperLiveKit
|
||||
cd WhisperLiveKit
|
||||
pip install -e .
|
||||
# Profile B: Voxtral HF + translation
|
||||
uv sync --extra cu129 --extra voxtral-hf --extra translation
|
||||
```
|
||||
|
||||
### FFmpeg Dependency
|
||||
`voxtral-hf` and `diarization-sortformer` are intentionally incompatible extras and must be installed in separate environments.
|
||||
|
||||
See **Parameters & Configuration** below on how to use them.
|
||||
|
||||
<p align="center">
|
||||
<img src="benchmark_scatter.png" alt="Speed vs Accuracy tradeoff" width="700">
|
||||
</p>
|
||||
|
||||
See **[BENCHMARK.md](BENCHMARK.md)** for the full benchmark with tables, model size comparison, and more.
|
||||
We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR!
|
||||
|
||||
|
||||
|
||||
### Voxtral Backend
|
||||
|
||||
WhisperLiveKit supports [Voxtral Mini](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602),
|
||||
a 4B-parameter speech model from Mistral AI that natively handles 100+ languages with automatic
|
||||
language detection. Whisper also supports auto-detection (`--language auto`), but Voxtral's per-chunk
|
||||
detection is more reliable and does not bias towards English.
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt install ffmpeg
|
||||
# Apple Silicon (native MLX, recommended)
|
||||
pip install -e ".[voxtral-mlx]"
|
||||
wlk --backend voxtral-mlx
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows
|
||||
# Download from https://ffmpeg.org/download.html and add to PATH
|
||||
# Linux/GPU (HuggingFace transformers)
|
||||
pip install transformers torch
|
||||
wlk --backend voxtral
|
||||
```
|
||||
|
||||
### Optional Dependencies
|
||||
Voxtral uses its own streaming policy and does not use LocalAgreement or SimulStreaming.
|
||||
See [BENCHMARK.md](BENCHMARK.md) for performance numbers.
|
||||
|
||||
### Usage Examples
|
||||
|
||||
**Command-line Interface**: Start the transcription server with various options:
|
||||
|
||||
```bash
|
||||
# Voice Activity Controller (prevents hallucinations)
|
||||
pip install torch
|
||||
# Large model and translate from french to danish
|
||||
wlk --model large-v3 --language fr --target-language da
|
||||
|
||||
# Sentence-based buffer trimming
|
||||
pip install mosestokenizer wtpsplit
|
||||
pip install tokenize_uk # If you work with Ukrainian text
|
||||
# Diarization and server listening on */80
|
||||
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
|
||||
# Speaker diarization
|
||||
pip install diart
|
||||
|
||||
# Alternative Whisper backends (default is faster-whisper)
|
||||
pip install whisperlivekit[whisper] # Original Whisper
|
||||
pip install whisperlivekit[whisper-timestamped] # Improved timestamps
|
||||
pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
|
||||
pip install whisperlivekit[openai] # OpenAI API
|
||||
pip install whisperlivekit[simulstreaming]
|
||||
```
|
||||
|
||||
### 🎹 Pyannote Models Setup
|
||||
|
||||
For diarization, you need access to pyannote.audio models:
|
||||
|
||||
1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||
2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
||||
3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
|
||||
4. Login with HuggingFace:
|
||||
```bash
|
||||
pip install huggingface_hub
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## 💻 Usage Examples
|
||||
|
||||
### Command-line Interface
|
||||
|
||||
Start the transcription server with various options:
|
||||
|
||||
```bash
|
||||
# Basic server with English model
|
||||
whisperlivekit-server --model tiny.en
|
||||
|
||||
# Advanced configuration with diarization
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto
|
||||
|
||||
# SimulStreaming backend for ultra-low latency
|
||||
whisperlivekit-server --backend simulstreaming --model large-v3 --frame-threshold 20
|
||||
# Voxtral multilingual (auto-detects language)
|
||||
wlk --backend voxtral-mlx
|
||||
```
|
||||
|
||||
|
||||
### Python API Integration (Backend)
|
||||
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
|
||||
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
|
||||
|
||||
```python
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
from whisperlivekit import AudioProcessor, TranscriptionEngine, parse_args
|
||||
|
||||
transcription_engine = None
|
||||
|
||||
@@ -145,14 +161,10 @@ transcription_engine = None
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
||||
# You can also load from command-line arguments using parse_args()
|
||||
# args = parse_args()
|
||||
# transcription_engine = TranscriptionEngine(**vars(args))
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Process WebSocket connections
|
||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
@@ -172,44 +184,49 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
await audio_processor.process_audio(message)
|
||||
```
|
||||
|
||||
### Frontend Implementation
|
||||
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`
|
||||
|
||||
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or load its content using `get_web_interface_html()` :
|
||||
|
||||
```python
|
||||
from whisperlivekit import get_web_interface_html
|
||||
html_content = get_web_interface_html()
|
||||
```
|
||||
## Parameters & Configuration
|
||||
|
||||
## ⚙️ Configuration Reference
|
||||
|
||||
WhisperLiveKit offers extensive configuration options:
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
|
||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
|
||||
| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||
| `--backend` | ASR backend selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. Options: `mlx-whisper`, `faster-whisper`, `whisper`, `openai-api` (LocalAgreement only), `voxtral-mlx` (Apple Silicon), `voxtral` (HuggingFace) | `auto` |
|
||||
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
| `--host` | Server host address | `localhost` |
|
||||
| `--port` | Server port | `8000` |
|
||||
| `--model` | Whisper model size. Caution : '.en' models do not work with Simulstreaming | `tiny` |
|
||||
| `--language` | Source language code or `auto` | `en` |
|
||||
| `--task` | `transcribe` or `translate` | `transcribe` |
|
||||
| `--backend` | Processing backend | `faster-whisper` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
|
||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
||||
| `--vac` | Use Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
|
||||
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
||||
| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). Only works with native Whisper backend (`--backend whisper`) | `None` |
|
||||
|
||||
**SimulStreaming-specific Options:**
|
||||
|
||||
| Parameter | Description | Default |
|
||||
| Translation options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
|
||||
| `--nllb-size` | `600M` or `1.3B` | `600M` |
|
||||
|
||||
| Diarization options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||
| `--disable-punctuation-split` | [NOT FUNCTIONAL IN 0.2.15 / 0.2.16] Disable punctuation based splits. See #214 | `False` |
|
||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
| SimulStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
|
||||
| `None` |
|
||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||
@@ -219,82 +236,124 @@ WhisperLiveKit offers extensive configuration options:
|
||||
| `--never-fire` | Never truncate incomplete words | `False` |
|
||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
||||
| `--max-context-tokens` | Maximum context tokens | Depends on model used, but usually 448. |
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format
|
||||
2. **Streaming**: Audio chunks are sent to the server via WebSocket
|
||||
3. **Processing**: Server decodes audio with FFmpeg and streams into the model for transcription
|
||||
4. **Real-time Output**: Partial transcriptions appear immediately in light gray (the 'aperçu') and finalized text appears in normal color
|
||||
|
||||
## 🚀 Deployment Guide
|
||||
| WhisperStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||
|
||||
|
||||
|
||||
|
||||
> For diarization using Diart, you need to accept user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, login to HuggingFace: `huggingface-cli login`
|
||||
|
||||
### 🚀 Deployment Guide
|
||||
|
||||
To deploy WhisperLiveKit in production:
|
||||
|
||||
1. **Server Setup** (Backend):
|
||||
|
||||
1. **Server Setup**: Install production ASGI server & launch with multiple workers
|
||||
```bash
|
||||
# Install production ASGI server
|
||||
pip install uvicorn gunicorn
|
||||
|
||||
# Launch with multiple workers
|
||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
||||
```
|
||||
|
||||
2. **Frontend Integration**:
|
||||
- Host your customized version of the example HTML/JS in your web application
|
||||
- Ensure WebSocket connection points to your server's address
|
||||
2. **Frontend**: Host your customized version of the `html` example & ensure WebSocket connection points correctly
|
||||
|
||||
3. **Nginx Configuration** (recommended for production):
|
||||
```nginx
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name your-domain.com;
|
||||
|
||||
location / {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
location / {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
}}
|
||||
```
|
||||
|
||||
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
||||
|
||||
### 🐋 Docker
|
||||
## 🐋 Docker
|
||||
|
||||
A basic Dockerfile is provided which allows re-use of Python package installation options. ⚠️ For **large** models, ensure that your **docker runtime** has enough **memory** available. See below usage examples:
|
||||
Deploy the application easily using Docker with GPU or CPU support.
|
||||
|
||||
### Prerequisites
|
||||
- Docker installed on your system
|
||||
- For GPU support: NVIDIA Docker runtime installed
|
||||
|
||||
#### All defaults
|
||||
- Create a reusable image with only the basics and then run as a named container:
|
||||
### Quick Start
|
||||
|
||||
**With GPU acceleration (recommended):**
|
||||
```bash
|
||||
docker build -t whisperlivekit-defaults .
|
||||
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
|
||||
docker start -i whisperlivekit
|
||||
docker build -t wlk .
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
|
||||
**CPU only:**
|
||||
```bash
|
||||
docker build -f Dockerfile.cpu -t wlk --build-arg EXTRAS="cpu" .
|
||||
docker run -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
### Advanced Usage
|
||||
|
||||
**Custom configuration:**
|
||||
```bash
|
||||
# Example with custom model and language
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||
```
|
||||
|
||||
**Compose (recommended for cache + token wiring):**
|
||||
```bash
|
||||
# GPU Sortformer profile
|
||||
docker compose up --build wlk-gpu-sortformer
|
||||
|
||||
# GPU Voxtral profile
|
||||
docker compose up --build wlk-gpu-voxtral
|
||||
|
||||
# CPU service
|
||||
docker compose up --build wlk-cpu
|
||||
```
|
||||
|
||||
### Memory Requirements
|
||||
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
|
||||
|
||||
|
||||
#### Customization
|
||||
- Customize the container options:
|
||||
```bash
|
||||
docker build -t whisperlivekit-defaults .
|
||||
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
|
||||
docker start -i whisperlivekit-base
|
||||
```
|
||||
|
||||
- `--build-arg` Options:
|
||||
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
||||
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
||||
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
||||
- `EXTRAS="cu129,diarization-sortformer"` - GPU Sortformer profile extras.
|
||||
- `EXTRAS="cu129,voxtral-hf,translation"` - GPU Voxtral profile extras.
|
||||
- `EXTRAS="cpu,diarization-diart,translation"` - CPU profile extras.
|
||||
- Hugging Face cache + token are configured in `compose.yml` using a named volume and `HF_TKN_FILE` (default: `./token`).
|
||||
|
||||
## 🔮 Use Cases
|
||||
## Testing & Benchmarks
|
||||
|
||||
WhisperLiveKit includes a unit test suite and an offline benchmark harness.
|
||||
|
||||
```bash
|
||||
# Install test dependencies
|
||||
pip install -e ".[test]"
|
||||
|
||||
# Run unit tests (no model download required)
|
||||
pytest tests/ -v
|
||||
|
||||
# Benchmark a single backend
|
||||
python test_backend_offline.py --backend faster-whisper --no-realtime
|
||||
|
||||
# Benchmark all installed backends
|
||||
python test_backend_offline.py --benchmark --no-realtime
|
||||
|
||||
# Export benchmark results as JSON
|
||||
python test_backend_offline.py --benchmark --no-realtime --json results.json
|
||||
```
|
||||
|
||||
See [BENCHMARK.md](BENCHMARK.md) for a full comparison of backends, policies, WER, speed, and
|
||||
timestamp accuracy on Apple Silicon.
|
||||
|
||||
## Use Cases
|
||||
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
We extend our gratitude to the original authors of:
|
||||
|
||||
| [Whisper Streaming](https://github.com/ufal/whisper_streaming) | [SimulStreaming](https://github.com/ufal/SimulStreaming) | [Diart](https://github.com/juanmc2005/diart) | [OpenAI Whisper](https://github.com/openai/whisper) |
|
||||
| -------- | ------- | -------- | ------- |
|
||||
|
||||
BIN
architecture.png
Normal file
|
After Width: | Height: | Size: 422 KiB |
97
audio_tests/00_00_07_english_1_speaker.transcript.json
Normal file
@@ -0,0 +1,97 @@
|
||||
[
|
||||
{
|
||||
"word": "This",
|
||||
"start": 0.0,
|
||||
"end": 0.24
|
||||
},
|
||||
{
|
||||
"word": "is",
|
||||
"start": 0.24,
|
||||
"end": 0.56
|
||||
},
|
||||
{
|
||||
"word": "a",
|
||||
"start": 0.56,
|
||||
"end": 0.76
|
||||
},
|
||||
{
|
||||
"word": "transcription",
|
||||
"start": 0.76,
|
||||
"end": 1.32
|
||||
},
|
||||
{
|
||||
"word": "test.",
|
||||
"start": 1.32,
|
||||
"end": 2.0
|
||||
},
|
||||
{
|
||||
"word": "We",
|
||||
"start": 2.4,
|
||||
"end": 2.5
|
||||
},
|
||||
{
|
||||
"word": "want",
|
||||
"start": 2.5,
|
||||
"end": 2.66
|
||||
},
|
||||
{
|
||||
"word": "to",
|
||||
"start": 2.66,
|
||||
"end": 2.84
|
||||
},
|
||||
{
|
||||
"word": "see",
|
||||
"start": 2.84,
|
||||
"end": 3.1
|
||||
},
|
||||
{
|
||||
"word": "if",
|
||||
"start": 3.1,
|
||||
"end": 3.34
|
||||
},
|
||||
{
|
||||
"word": "we",
|
||||
"start": 3.34,
|
||||
"end": 3.5
|
||||
},
|
||||
{
|
||||
"word": "can",
|
||||
"start": 3.5,
|
||||
"end": 3.68
|
||||
},
|
||||
{
|
||||
"word": "use",
|
||||
"start": 3.68,
|
||||
"end": 4.04
|
||||
},
|
||||
{
|
||||
"word": "smaller",
|
||||
"start": 4.04,
|
||||
"end": 4.76
|
||||
},
|
||||
{
|
||||
"word": "chunks.",
|
||||
"start": 4.76,
|
||||
"end": 5.16
|
||||
},
|
||||
{
|
||||
"word": "What",
|
||||
"start": 6.06,
|
||||
"end": 6.32
|
||||
},
|
||||
{
|
||||
"word": "do",
|
||||
"start": 6.32,
|
||||
"end": 6.44
|
||||
},
|
||||
{
|
||||
"word": "you",
|
||||
"start": 6.44,
|
||||
"end": 6.58
|
||||
},
|
||||
{
|
||||
"word": "think?",
|
||||
"start": 6.58,
|
||||
"end": 6.84
|
||||
}
|
||||
]
|
||||
177
audio_tests/00_00_16_french_1_speaker.transcript.json
Normal file
@@ -0,0 +1,177 @@
|
||||
[
|
||||
{
|
||||
"word": "Ok,",
|
||||
"start": 2.02,
|
||||
"end": 2.38
|
||||
},
|
||||
{
|
||||
"word": "là",
|
||||
"start": 2.52,
|
||||
"end": 2.58
|
||||
},
|
||||
{
|
||||
"word": "c",
|
||||
"start": 2.58,
|
||||
"end": 2.74
|
||||
},
|
||||
{
|
||||
"word": "'est",
|
||||
"start": 2.74,
|
||||
"end": 2.76
|
||||
},
|
||||
{
|
||||
"word": "un",
|
||||
"start": 2.76,
|
||||
"end": 2.86
|
||||
},
|
||||
{
|
||||
"word": "test,",
|
||||
"start": 2.86,
|
||||
"end": 3.2
|
||||
},
|
||||
{
|
||||
"word": "on",
|
||||
"start": 3.34,
|
||||
"end": 3.34
|
||||
},
|
||||
{
|
||||
"word": "veut",
|
||||
"start": 3.34,
|
||||
"end": 3.48
|
||||
},
|
||||
{
|
||||
"word": "voir",
|
||||
"start": 3.48,
|
||||
"end": 3.86
|
||||
},
|
||||
{
|
||||
"word": "si",
|
||||
"start": 3.86,
|
||||
"end": 4.14
|
||||
},
|
||||
{
|
||||
"word": "ça",
|
||||
"start": 4.14,
|
||||
"end": 4.26
|
||||
},
|
||||
{
|
||||
"word": "arrive",
|
||||
"start": 4.26,
|
||||
"end": 4.36
|
||||
},
|
||||
{
|
||||
"word": "à",
|
||||
"start": 4.36,
|
||||
"end": 4.5
|
||||
},
|
||||
{
|
||||
"word": "capté",
|
||||
"start": 4.5,
|
||||
"end": 4.78
|
||||
},
|
||||
{
|
||||
"word": "le",
|
||||
"start": 4.78,
|
||||
"end": 4.9
|
||||
},
|
||||
{
|
||||
"word": "silence.",
|
||||
"start": 4.9,
|
||||
"end": 5.44
|
||||
},
|
||||
{
|
||||
"word": "Là",
|
||||
"start": 9.24,
|
||||
"end": 9.6
|
||||
},
|
||||
{
|
||||
"word": "il",
|
||||
"start": 9.6,
|
||||
"end": 9.78
|
||||
},
|
||||
{
|
||||
"word": "est",
|
||||
"start": 9.78,
|
||||
"end": 9.84
|
||||
},
|
||||
{
|
||||
"word": "une",
|
||||
"start": 9.84,
|
||||
"end": 9.96
|
||||
},
|
||||
{
|
||||
"word": "telle",
|
||||
"start": 9.96,
|
||||
"end": 10.12
|
||||
},
|
||||
{
|
||||
"word": "seconde",
|
||||
"start": 10.12,
|
||||
"end": 10.38
|
||||
},
|
||||
{
|
||||
"word": "de",
|
||||
"start": 10.38,
|
||||
"end": 10.48
|
||||
},
|
||||
{
|
||||
"word": "silence",
|
||||
"start": 10.48,
|
||||
"end": 10.78
|
||||
},
|
||||
{
|
||||
"word": "et",
|
||||
"start": 10.78,
|
||||
"end": 11.06
|
||||
},
|
||||
{
|
||||
"word": "je",
|
||||
"start": 11.06,
|
||||
"end": 11.16
|
||||
},
|
||||
{
|
||||
"word": "vous",
|
||||
"start": 11.16,
|
||||
"end": 11.32
|
||||
},
|
||||
{
|
||||
"word": "parle.",
|
||||
"start": 11.32,
|
||||
"end": 11.68
|
||||
},
|
||||
{
|
||||
"word": "Et",
|
||||
"start": 13.28,
|
||||
"end": 13.64
|
||||
},
|
||||
{
|
||||
"word": "voilà,",
|
||||
"start": 13.64,
|
||||
"end": 13.96
|
||||
},
|
||||
{
|
||||
"word": "allez",
|
||||
"start": 14.36,
|
||||
"end": 14.62
|
||||
},
|
||||
{
|
||||
"word": "on",
|
||||
"start": 14.62,
|
||||
"end": 14.78
|
||||
},
|
||||
{
|
||||
"word": "va",
|
||||
"start": 14.78,
|
||||
"end": 14.88
|
||||
},
|
||||
{
|
||||
"word": "tester",
|
||||
"start": 14.88,
|
||||
"end": 15.06
|
||||
},
|
||||
{
|
||||
"word": "ça.",
|
||||
"start": 15.06,
|
||||
"end": 15.36
|
||||
}
|
||||
]
|
||||
382
audio_tests/00_00_30_english_3_speakers.transcript.json
Normal file
@@ -0,0 +1,382 @@
|
||||
[
|
||||
{
|
||||
"word": "Transcription",
|
||||
"start": 0.0,
|
||||
"end": 0.6
|
||||
},
|
||||
{
|
||||
"word": "technology",
|
||||
"start": 0.6,
|
||||
"end": 1.24
|
||||
},
|
||||
{
|
||||
"word": "has",
|
||||
"start": 1.24,
|
||||
"end": 1.5
|
||||
},
|
||||
{
|
||||
"word": "improved",
|
||||
"start": 1.5,
|
||||
"end": 1.96
|
||||
},
|
||||
{
|
||||
"word": "so",
|
||||
"start": 1.96,
|
||||
"end": 2.32
|
||||
},
|
||||
{
|
||||
"word": "much",
|
||||
"start": 2.32,
|
||||
"end": 2.68
|
||||
},
|
||||
{
|
||||
"word": "in",
|
||||
"start": 2.68,
|
||||
"end": 2.94
|
||||
},
|
||||
{
|
||||
"word": "the",
|
||||
"start": 2.94,
|
||||
"end": 3.02
|
||||
},
|
||||
{
|
||||
"word": "past",
|
||||
"start": 3.02,
|
||||
"end": 3.24
|
||||
},
|
||||
{
|
||||
"word": "few",
|
||||
"start": 3.24,
|
||||
"end": 3.5
|
||||
},
|
||||
{
|
||||
"word": "years.",
|
||||
"start": 3.5,
|
||||
"end": 3.96
|
||||
},
|
||||
{
|
||||
"word": "Have",
|
||||
"start": 4.56,
|
||||
"end": 4.74
|
||||
},
|
||||
{
|
||||
"word": "you",
|
||||
"start": 4.74,
|
||||
"end": 4.9
|
||||
},
|
||||
{
|
||||
"word": "noticed",
|
||||
"start": 4.9,
|
||||
"end": 5.26
|
||||
},
|
||||
{
|
||||
"word": "how",
|
||||
"start": 5.26,
|
||||
"end": 5.52
|
||||
},
|
||||
{
|
||||
"word": "accurate",
|
||||
"start": 5.52,
|
||||
"end": 6.08
|
||||
},
|
||||
{
|
||||
"word": "real",
|
||||
"start": 6.08,
|
||||
"end": 6.42
|
||||
},
|
||||
{
|
||||
"word": "-time",
|
||||
"start": 6.42,
|
||||
"end": 6.74
|
||||
},
|
||||
{
|
||||
"word": "speech",
|
||||
"start": 6.74,
|
||||
"end": 7.24
|
||||
},
|
||||
{
|
||||
"word": "to",
|
||||
"start": 7.24,
|
||||
"end": 7.46
|
||||
},
|
||||
{
|
||||
"word": "text",
|
||||
"start": 7.46,
|
||||
"end": 7.78
|
||||
},
|
||||
{
|
||||
"word": "is",
|
||||
"start": 7.78,
|
||||
"end": 8.0
|
||||
},
|
||||
{
|
||||
"word": "now?",
|
||||
"start": 8.0,
|
||||
"end": 8.3
|
||||
},
|
||||
{
|
||||
"word": "Absolutely.",
|
||||
"start": 8.7,
|
||||
"end": 9.16
|
||||
},
|
||||
{
|
||||
"word": "I",
|
||||
"start": 10.04,
|
||||
"end": 10.38
|
||||
},
|
||||
{
|
||||
"word": "use",
|
||||
"start": 10.38,
|
||||
"end": 10.56
|
||||
},
|
||||
{
|
||||
"word": "it",
|
||||
"start": 10.56,
|
||||
"end": 10.76
|
||||
},
|
||||
{
|
||||
"word": "all",
|
||||
"start": 10.76,
|
||||
"end": 10.9
|
||||
},
|
||||
{
|
||||
"word": "the",
|
||||
"start": 10.9,
|
||||
"end": 11.04
|
||||
},
|
||||
{
|
||||
"word": "time",
|
||||
"start": 11.04,
|
||||
"end": 11.32
|
||||
},
|
||||
{
|
||||
"word": "for",
|
||||
"start": 11.32,
|
||||
"end": 11.54
|
||||
},
|
||||
{
|
||||
"word": "taking",
|
||||
"start": 11.54,
|
||||
"end": 11.86
|
||||
},
|
||||
{
|
||||
"word": "notes",
|
||||
"start": 11.86,
|
||||
"end": 12.16
|
||||
},
|
||||
{
|
||||
"word": "during",
|
||||
"start": 12.16,
|
||||
"end": 12.54
|
||||
},
|
||||
{
|
||||
"word": "meetings.",
|
||||
"start": 12.54,
|
||||
"end": 12.94
|
||||
},
|
||||
{
|
||||
"word": "It's",
|
||||
"start": 13.6,
|
||||
"end": 13.8
|
||||
},
|
||||
{
|
||||
"word": "amazing",
|
||||
"start": 13.8,
|
||||
"end": 14.1
|
||||
},
|
||||
{
|
||||
"word": "how",
|
||||
"start": 14.1,
|
||||
"end": 14.48
|
||||
},
|
||||
{
|
||||
"word": "it",
|
||||
"start": 14.48,
|
||||
"end": 14.62
|
||||
},
|
||||
{
|
||||
"word": "can",
|
||||
"start": 14.62,
|
||||
"end": 14.74
|
||||
},
|
||||
{
|
||||
"word": "recognise",
|
||||
"start": 14.74,
|
||||
"end": 15.24
|
||||
},
|
||||
{
|
||||
"word": "different",
|
||||
"start": 15.24,
|
||||
"end": 15.68
|
||||
},
|
||||
{
|
||||
"word": "speakers",
|
||||
"start": 15.68,
|
||||
"end": 16.16
|
||||
},
|
||||
{
|
||||
"word": "and",
|
||||
"start": 16.16,
|
||||
"end": 16.8
|
||||
},
|
||||
{
|
||||
"word": "even",
|
||||
"start": 16.8,
|
||||
"end": 17.1
|
||||
},
|
||||
{
|
||||
"word": "add",
|
||||
"start": 17.1,
|
||||
"end": 17.44
|
||||
},
|
||||
{
|
||||
"word": "punctuation.",
|
||||
"start": 17.44,
|
||||
"end": 18.36
|
||||
},
|
||||
{
|
||||
"word": "Yeah,",
|
||||
"start": 18.88,
|
||||
"end": 19.16
|
||||
},
|
||||
{
|
||||
"word": "but",
|
||||
"start": 19.36,
|
||||
"end": 19.52
|
||||
},
|
||||
{
|
||||
"word": "sometimes",
|
||||
"start": 19.52,
|
||||
"end": 20.16
|
||||
},
|
||||
{
|
||||
"word": "noise",
|
||||
"start": 20.16,
|
||||
"end": 20.54
|
||||
},
|
||||
{
|
||||
"word": "can",
|
||||
"start": 20.54,
|
||||
"end": 20.8
|
||||
},
|
||||
{
|
||||
"word": "still",
|
||||
"start": 20.8,
|
||||
"end": 21.1
|
||||
},
|
||||
{
|
||||
"word": "cause",
|
||||
"start": 21.1,
|
||||
"end": 21.44
|
||||
},
|
||||
{
|
||||
"word": "mistakes.",
|
||||
"start": 21.44,
|
||||
"end": 21.94
|
||||
},
|
||||
{
|
||||
"word": "Does",
|
||||
"start": 22.68,
|
||||
"end": 22.9
|
||||
},
|
||||
{
|
||||
"word": "this",
|
||||
"start": 22.9,
|
||||
"end": 23.12
|
||||
},
|
||||
{
|
||||
"word": "system",
|
||||
"start": 23.12,
|
||||
"end": 23.46
|
||||
},
|
||||
{
|
||||
"word": "handle",
|
||||
"start": 23.46,
|
||||
"end": 23.88
|
||||
},
|
||||
{
|
||||
"word": "that",
|
||||
"start": 23.88,
|
||||
"end": 24.12
|
||||
},
|
||||
{
|
||||
"word": "well?",
|
||||
"start": 24.12,
|
||||
"end": 24.42
|
||||
},
|
||||
{
|
||||
"word": "It",
|
||||
"start": 24.42,
|
||||
"end": 25.32
|
||||
},
|
||||
{
|
||||
"word": "does",
|
||||
"start": 25.32,
|
||||
"end": 25.48
|
||||
},
|
||||
{
|
||||
"word": "a",
|
||||
"start": 25.48,
|
||||
"end": 25.62
|
||||
},
|
||||
{
|
||||
"word": "pretty",
|
||||
"start": 25.62,
|
||||
"end": 25.88
|
||||
},
|
||||
{
|
||||
"word": "good",
|
||||
"start": 25.88,
|
||||
"end": 26.08
|
||||
},
|
||||
{
|
||||
"word": "job",
|
||||
"start": 26.08,
|
||||
"end": 26.32
|
||||
},
|
||||
{
|
||||
"word": "filtering",
|
||||
"start": 26.32,
|
||||
"end": 26.8
|
||||
},
|
||||
{
|
||||
"word": "noise,",
|
||||
"start": 26.8,
|
||||
"end": 27.18
|
||||
},
|
||||
{
|
||||
"word": "especially",
|
||||
"start": 27.36,
|
||||
"end": 28.0
|
||||
},
|
||||
{
|
||||
"word": "with",
|
||||
"start": 28.0,
|
||||
"end": 28.28
|
||||
},
|
||||
{
|
||||
"word": "models",
|
||||
"start": 28.28,
|
||||
"end": 28.62
|
||||
},
|
||||
{
|
||||
"word": "that",
|
||||
"start": 28.62,
|
||||
"end": 28.94
|
||||
},
|
||||
{
|
||||
"word": "use",
|
||||
"start": 28.94,
|
||||
"end": 29.22
|
||||
},
|
||||
{
|
||||
"word": "voice",
|
||||
"start": 29.22,
|
||||
"end": 29.54
|
||||
},
|
||||
{
|
||||
"word": "active.",
|
||||
"start": 29.54,
|
||||
"end": 29.9
|
||||
}
|
||||
]
|
||||
57
audio_tests/generate_transcripts.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate word-level timestamped transcripts using faster-whisper (offline).
|
||||
|
||||
Produces one JSON file per audio with: [{word, start, end}, ...]
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
AUDIO_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
FILES = [
|
||||
("00_00_07_english_1_speaker.wav", "en"),
|
||||
("00_00_16_french_1_speaker.wav", "fr"),
|
||||
("00_00_30_english_3_speakers.wav", "en"),
|
||||
]
|
||||
|
||||
def main():
|
||||
print("Loading faster-whisper model (base, cpu, float32)...")
|
||||
model = WhisperModel("base", device="cpu", compute_type="float32")
|
||||
|
||||
for filename, lang in FILES:
|
||||
audio_path = os.path.join(AUDIO_DIR, filename)
|
||||
out_path = os.path.join(
|
||||
AUDIO_DIR, filename.rsplit(".", 1)[0] + ".transcript.json"
|
||||
)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Transcribing: {filename} (language={lang})")
|
||||
print(f"{'='*60}")
|
||||
|
||||
segments, info = model.transcribe(
|
||||
audio_path, word_timestamps=True, language=lang
|
||||
)
|
||||
|
||||
words = []
|
||||
for segment in segments:
|
||||
if segment.words:
|
||||
for w in segment.words:
|
||||
words.append({
|
||||
"word": w.word.strip(),
|
||||
"start": round(w.start, 3),
|
||||
"end": round(w.end, 3),
|
||||
})
|
||||
print(f" {w.start:6.2f} - {w.end:6.2f} {w.word.strip()}")
|
||||
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(words, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\n -> {len(words)} words written to {os.path.basename(out_path)}")
|
||||
|
||||
print("\nDone.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
benchmark_chart.png
Normal file
|
After Width: | Height: | Size: 69 KiB |
BIN
benchmark_scatter.png
Normal file
|
After Width: | Height: | Size: 95 KiB |
19
chrome-extension/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
## WhisperLiveKit Chrome Extension v0.1.1
|
||||
Capture the audio of your current tab, transcribe diarize and translate it using WhisperliveKit, in Chrome and other Chromium-based browsers.
|
||||
|
||||
> Currently, only the tab audio is captured; your microphone audio is not recorded.
|
||||
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
||||
|
||||
## Running this extension
|
||||
1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
||||
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
||||
|
||||
|
||||
## Devs:
|
||||
- Impossible to capture audio from tabs if extension is a pannel, unfortunately:
|
||||
- https://issues.chromium.org/issues/40926394
|
||||
- https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
|
||||
- https://issues.chromium.org/issues/40916430
|
||||
|
||||
- To capture microphone in an extension, there are tricks: https://github.com/justinmann/sidepanel-audio-issue , https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (comments)
|
||||
9
chrome-extension/background.js
Normal file
@@ -0,0 +1,9 @@
|
||||
chrome.runtime.onInstalled.addListener((details) => {
|
||||
if (details.reason.search(/install/g) === -1) {
|
||||
return
|
||||
}
|
||||
chrome.tabs.create({
|
||||
url: chrome.runtime.getURL("welcome.html"),
|
||||
active: true
|
||||
})
|
||||
})
|
||||
BIN
chrome-extension/demo-extension.png
Normal file
|
After Width: | Height: | Size: 5.8 MiB |
BIN
chrome-extension/icons/icon128.png
Normal file
|
After Width: | Height: | Size: 5.8 KiB |
BIN
chrome-extension/icons/icon16.png
Normal file
|
After Width: | Height: | Size: 376 B |
BIN
chrome-extension/icons/icon32.png
Normal file
|
After Width: | Height: | Size: 823 B |
BIN
chrome-extension/icons/icon48.png
Normal file
|
After Width: | Height: | Size: 1.4 KiB |
23
chrome-extension/manifest.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"manifest_version": 3,
|
||||
"name": "WhisperLiveKit Tab Capture",
|
||||
"version": "1.0",
|
||||
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
|
||||
"icons": {
|
||||
"16": "icons/icon16.png",
|
||||
"32": "icons/icon32.png",
|
||||
"48": "icons/icon48.png",
|
||||
"128": "icons/icon128.png"
|
||||
},
|
||||
"action": {
|
||||
"default_title": "WhisperLiveKit Tab Capture",
|
||||
"default_popup": "live_transcription.html"
|
||||
},
|
||||
"permissions": [
|
||||
"scripting",
|
||||
"tabCapture",
|
||||
"offscreen",
|
||||
"activeTab",
|
||||
"storage"
|
||||
]
|
||||
}
|
||||
12
chrome-extension/requestPermissions.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Request Permissions</title>
|
||||
<script src="requestPermissions.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
This page exists to workaround an issue with Chrome that blocks permission
|
||||
requests from chrome extensions
|
||||
<button id="requestMicrophone">Request Microphone</button>
|
||||
</body>
|
||||
</html>
|
||||
17
chrome-extension/requestPermissions.js
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Requests user permission for microphone access.
|
||||
* @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
|
||||
*/
|
||||
async function getUserPermission() {
|
||||
console.log("Getting user permission for microphone access...");
|
||||
await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state == "granted") {
|
||||
window.close();
|
||||
}
|
||||
}
|
||||
|
||||
// Call the function to request microphone permission
|
||||
getUserPermission();
|
||||
29
chrome-extension/sidepanel.js
Normal file
@@ -0,0 +1,29 @@
|
||||
console.log("sidepanel.js");
|
||||
|
||||
async function run() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
|
||||
if (micPermission.state !== "granted") {
|
||||
chrome.tabs.create({ url: "requestPermissions.html" });
|
||||
}
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void run();
|
||||
52
compose.yml
Normal file
@@ -0,0 +1,52 @@
|
||||
services:
|
||||
wlk-gpu-sortformer:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
EXTRAS: ${GPU_SORTFORMER_EXTRAS:-cu129,diarization-sortformer}
|
||||
image: wlk:gpu-sortformer
|
||||
gpus: all
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- hf-cache:/root/.cache/huggingface/hub
|
||||
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||
environment:
|
||||
- HF_TOKEN
|
||||
command: ["--model", "medium", "--diarization", "--pcm-input"]
|
||||
|
||||
wlk-gpu-voxtral:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
EXTRAS: ${GPU_VOXTRAL_EXTRAS:-cu129,voxtral-hf,translation}
|
||||
image: wlk:gpu-voxtral
|
||||
gpus: all
|
||||
ports:
|
||||
- "8001:8000"
|
||||
volumes:
|
||||
- hf-cache:/root/.cache/huggingface/hub
|
||||
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||
environment:
|
||||
- HF_TOKEN
|
||||
command: ["--backend", "voxtral", "--pcm-input"]
|
||||
|
||||
wlk-cpu:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.cpu
|
||||
args:
|
||||
EXTRAS: ${CPU_EXTRAS:-cpu,diarization-diart,translation}
|
||||
image: wlk:cpu
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- hf-cache:/root/.cache/huggingface/hub
|
||||
# - ${HF_TKN_FILE:-./token}:/root/.cache/huggingface/token:ro
|
||||
environment:
|
||||
- HF_TOKEN
|
||||
|
||||
volumes:
|
||||
hf-cache:
|
||||
BIN
demo.png
|
Before Width: | Height: | Size: 438 KiB After Width: | Height: | Size: 985 KiB |
264
docs/API.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# WhisperLiveKit WebSocket API Documentation
|
||||
|
||||
> !! **Note**: The new API structure described in this document is currently under deployment.
|
||||
This documentation is intended for devs who want to build custom frontends.
|
||||
|
||||
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
|
||||
|
||||
---
|
||||
|
||||
## Legacy API (Current)
|
||||
|
||||
### Message Structure
|
||||
|
||||
The current API sends complete state snapshots on each update (several time per second)
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": str,
|
||||
"status": str,
|
||||
"lines": [
|
||||
{
|
||||
"speaker": int,
|
||||
"text": str,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"translation": str | null,
|
||||
"detected_language": str
|
||||
}
|
||||
],
|
||||
"buffer_transcription": str,
|
||||
"buffer_diarization": str,
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## New API (Under Development)
|
||||
|
||||
### Philosophy
|
||||
|
||||
Principles:
|
||||
|
||||
- **Incremental Updates**: Only updates and new segments are sent
|
||||
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
|
||||
|
||||
|
||||
## Message Format
|
||||
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": "transcript_update",
|
||||
"status": "active_transcription" | "no_audio_detected",
|
||||
"segments": [
|
||||
{
|
||||
"id": number,
|
||||
"speaker": number,
|
||||
"text": string,
|
||||
"start_speaker": float,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"language": string | null,
|
||||
"translation": string,
|
||||
"words": [
|
||||
{
|
||||
"text": string,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"validated": {
|
||||
"text": boolean,
|
||||
"speaker": boolean,
|
||||
}
|
||||
}
|
||||
],
|
||||
"buffer": {
|
||||
"transcription": string,
|
||||
"diarization": string,
|
||||
"translation": string
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Other Message Types
|
||||
|
||||
#### Config Message (sent on connection)
|
||||
```json
|
||||
{
|
||||
"type": "config",
|
||||
"useAudioWorklet": true / false
|
||||
}
|
||||
```
|
||||
|
||||
#### Ready to Stop Message (sent after processing complete)
|
||||
```json
|
||||
{
|
||||
"type": "ready_to_stop"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Field Descriptions
|
||||
|
||||
### Segment Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
|
||||
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
|
||||
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
|
||||
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
|
||||
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
|
||||
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
|
||||
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
|
||||
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
|
||||
| `words` | `Array` | Array of word-level objects with timing and validation information. |
|
||||
| `buffer` | `Object` | Per-segment temporary buffers, see below |
|
||||
|
||||
### Word Object
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `text` | `string` | The word text. |
|
||||
| `start` | `number` | Start timestamp (seconds) of this word. |
|
||||
| `end` | `number` | End timestamp (seconds) of this word. |
|
||||
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
|
||||
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
|
||||
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
|
||||
|
||||
### Buffer Object (Per-Segment)
|
||||
|
||||
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
|
||||
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
|
||||
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
|
||||
|
||||
|
||||
### Metadata Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
|
||||
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
|
||||
|
||||
### Status Values
|
||||
|
||||
| Status | Description |
|
||||
|--------|-------------|
|
||||
| `active_transcription` | Normal operation, transcription is active. |
|
||||
| `no_audio_detected` | No audio has been detected yet. |
|
||||
|
||||
---
|
||||
|
||||
## Update Behavior
|
||||
|
||||
### Incremental Updates
|
||||
|
||||
The API sends **only changed or new segments**. Clients should:
|
||||
|
||||
1. Maintain a local map of segments by ID
|
||||
2. When receiving an update, merge/update segments by ID
|
||||
3. Render only the changed segments
|
||||
|
||||
### Language Detection
|
||||
|
||||
When language is detected for a segment:
|
||||
|
||||
```jsonc
|
||||
// Update 1: No language yet
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "May see", "language": null}
|
||||
]
|
||||
}
|
||||
|
||||
// Update 2: Same segment ID, language now detected
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Client behavior**: **Replace** the existing segment with the same ID.
|
||||
|
||||
### Buffer Behavior
|
||||
|
||||
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
|
||||
|
||||
#### Example: Translation with diarization and translation
|
||||
|
||||
```jsonc
|
||||
// Update 1
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": "Hello world, how are",
|
||||
"translation": "",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " you on",
|
||||
"translation": "Bonjour le monde"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
|
||||
|
||||
|
||||
// Update 2
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": " you on this",
|
||||
"translation": "Bonjour tout le monde",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " beautiful day",
|
||||
"translation": ",comment"
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
|
||||
```
|
||||
|
||||
### Silence Segments
|
||||
|
||||
Silence is represented with the speaker id = `-2`:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"id": 5,
|
||||
"speaker": -2,
|
||||
"text": "",
|
||||
"start": 10.5,
|
||||
"end": 12.3
|
||||
}
|
||||
```
|
||||
71
docs/alignement_principles.md
Normal file
@@ -0,0 +1,71 @@
|
||||
### Alignment between STT Tokens and Diarization Segments
|
||||
|
||||
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
|
||||
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
|
||||
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
|
||||
|
||||
> `#` Is the split between the `t-1` prediction and `t` prediction.
|
||||
|
||||
|
||||
## Example 1:
|
||||
```text
|
||||
punctuations_segments : __#_______.__________________!____
|
||||
diarization_segments:
|
||||
SPK1 __#____________
|
||||
SPK2 # ___________________
|
||||
-->
|
||||
ALIGNED SPK1 __#_______.
|
||||
ALIGNED SPK2 # __________________!____
|
||||
|
||||
t-1 output:
|
||||
SPK1: __#
|
||||
SPK2: NO
|
||||
DIARIZATION BUFFER: NO
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: __________________!____
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 2:
|
||||
```text
|
||||
punctuations_segments : _____#__.___________
|
||||
diarization_segments:
|
||||
SPK1 ___ #
|
||||
SPK2 __#______________
|
||||
-->
|
||||
ALIGNED SPK1 _____#__.
|
||||
ALIGNED SPK2 # ___________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___ #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: ___________
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 3:
|
||||
```text
|
||||
punctuations_segments : ___.__#__________
|
||||
diarization_segments:
|
||||
SPK1 ______#__
|
||||
SPK2 # ________
|
||||
-->
|
||||
ALIGNED SPK1 ___. #
|
||||
ALIGNED SPK2 __#__________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___. #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: #
|
||||
SPK2: __#___________
|
||||
DIARIZATION BUFFER: NO
|
||||
```
|
||||
106
docs/default_and_custom_models.md
Normal file
@@ -0,0 +1,106 @@
|
||||
# Models and Model Paths
|
||||
|
||||
## Defaults
|
||||
|
||||
**Default Whisper Model**: `base`
|
||||
When no model is specified, WhisperLiveKit uses the `base` model, which provides a good balance of speed and accuracy for most use cases.
|
||||
|
||||
**Default Model Cache Directory**: `~/.cache/whisper`
|
||||
Models are automatically downloaded from OpenAI's model hub and cached in this directory. You can override this with `--model_cache_dir`.
|
||||
|
||||
**Default Translation Model**: `600M` (NLLB-200-distilled)
|
||||
When translation is enabled, the 600M distilled NLLB model is used by default. This provides good quality with minimal resource usage.
|
||||
|
||||
**Default Translation Backend**: `transformers`
|
||||
The translation backend defaults to Transformers. On Apple Silicon, this automatically uses MPS acceleration for better performance.
|
||||
|
||||
---
|
||||
|
||||
|
||||
## Available Whisper model sizes:
|
||||
|
||||
| Available Model | Speed | Accuracy | Multilingual | Translation | Hardware Requirements | Best Use Case |
|
||||
|--------------------|----------|-----------|--------------|-------------|----------------------|----------------------------------|
|
||||
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | ~1GB VRAM | Real-time, low resources |
|
||||
| base(.en) | Fast | Good | Yes/No | Yes/No | ~1GB VRAM | Balanced performance |
|
||||
| small(.en) | Medium | Better | Yes/No | Yes/No | ~2GB VRAM | Quality on limited hardware |
|
||||
| medium(.en) | Slow | High | Yes/No | Yes/No | ~5GB VRAM | High quality, moderate resources |
|
||||
| large-v2 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Good overall accuracy & language support |
|
||||
| large-v3 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Best overall accuracy & language support |
|
||||
| large-v3-turbo | Fast | Excellent | Yes | No | ~6GB VRAM | Fast, high-quality transcription |
|
||||
|
||||
|
||||
### How to choose?
|
||||
|
||||
#### Language Support
|
||||
- **English only**: Use `.en` (ex: `base.en`) models for better accuracy and faster processing when you only need English transcription
|
||||
- **Multilingual**: Do not use `.en` models.
|
||||
|
||||
#### Special Cases
|
||||
- **No translation needed**: Use `large-v3-turbo`
|
||||
- Same transcription quality as `large-v2` but significantly faster
|
||||
- **Important**: Does not translate correctly, only transcribes
|
||||
|
||||
### Additional Considerations
|
||||
|
||||
**Model Performance**:
|
||||
- Accuracy improves significantly from tiny to large models
|
||||
- English-only models are ~10-15% more accurate for English audio
|
||||
- Newer versions (v2, v3) have better punctuation and formatting
|
||||
|
||||
**Audio Quality Impact**:
|
||||
- Clean, clear audio: smaller models may suffice
|
||||
- Noisy, accented, or technical audio: larger models recommended
|
||||
- Phone/low-quality audio: use at least `small` model
|
||||
|
||||
_______________________
|
||||
|
||||
|
||||
# Custom Models:
|
||||
|
||||
The `--model-path` parameter accepts:
|
||||
|
||||
## File Path
|
||||
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
|
||||
|
||||
## Directory Path (recommended)
|
||||
Must contain:
|
||||
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
|
||||
|
||||
May optionally contain:
|
||||
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||
|
||||
## Hugging Face Repo ID
|
||||
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
||||
|
||||
To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignment heads are set to be all the heads of the last half layer of decoder.
|
||||
|
||||
|
||||
_______________________
|
||||
|
||||
# Translation Models and Backend
|
||||
|
||||
**Language Support**: ~200 languages
|
||||
|
||||
## Distilled Model Sizes Available
|
||||
|
||||
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|
||||
|-------|------|------------|-------------|-------------|---------|
|
||||
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
|
||||
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
|
||||
|
||||
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
|
||||
|
||||
## Backend Performance
|
||||
|
||||
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|
||||
|---------|---------------|--------------|--------------|
|
||||
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
|
||||
| Transformers | Baseline | High | None |
|
||||
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
|
||||
|
||||
**Metrics**:
|
||||
- CTranslate2: 50-100+ tokens/sec
|
||||
- Transformers: 10-30 tokens/sec
|
||||
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
|
||||
373
docs/supported_languages.md
Normal file
@@ -0,0 +1,373 @@
|
||||
# Transcription: Supported Language
|
||||
|
||||
WLK supports transcription in the following languages:
|
||||
|
||||
| ISO Code | Language Name |
|
||||
|----------|---------------------|
|
||||
| en | English |
|
||||
| zh | Chinese |
|
||||
| de | German |
|
||||
| es | Spanish |
|
||||
| ru | Russian |
|
||||
| ko | Korean |
|
||||
| fr | French |
|
||||
| ja | Japanese |
|
||||
| pt | Portuguese |
|
||||
| tr | Turkish |
|
||||
| pl | Polish |
|
||||
| ca | Catalan |
|
||||
| nl | Dutch |
|
||||
| ar | Arabic |
|
||||
| sv | Swedish |
|
||||
| it | Italian |
|
||||
| id | Indonesian |
|
||||
| hi | Hindi |
|
||||
| fi | Finnish |
|
||||
| vi | Vietnamese |
|
||||
| he | Hebrew |
|
||||
| uk | Ukrainian |
|
||||
| el | Greek |
|
||||
| ms | Malay |
|
||||
| cs | Czech |
|
||||
| ro | Romanian |
|
||||
| da | Danish |
|
||||
| hu | Hungarian |
|
||||
| ta | Tamil |
|
||||
| no | Norwegian |
|
||||
| th | Thai |
|
||||
| ur | Urdu |
|
||||
| hr | Croatian |
|
||||
| bg | Bulgarian |
|
||||
| lt | Lithuanian |
|
||||
| la | Latin |
|
||||
| mi | Maori |
|
||||
| ml | Malayalam |
|
||||
| cy | Welsh |
|
||||
| sk | Slovak |
|
||||
| te | Telugu |
|
||||
| fa | Persian |
|
||||
| lv | Latvian |
|
||||
| bn | Bengali |
|
||||
| sr | Serbian |
|
||||
| az | Azerbaijani |
|
||||
| sl | Slovenian |
|
||||
| kn | Kannada |
|
||||
| et | Estonian |
|
||||
| mk | Macedonian |
|
||||
| br | Breton |
|
||||
| eu | Basque |
|
||||
| is | Icelandic |
|
||||
| hy | Armenian |
|
||||
| ne | Nepali |
|
||||
| mn | Mongolian |
|
||||
| bs | Bosnian |
|
||||
| kk | Kazakh |
|
||||
| sq | Albanian |
|
||||
| sw | Swahili |
|
||||
| gl | Galician |
|
||||
| mr | Marathi |
|
||||
| pa | Punjabi |
|
||||
| si | Sinhala |
|
||||
| km | Khmer |
|
||||
| sn | Shona |
|
||||
| yo | Yoruba |
|
||||
| so | Somali |
|
||||
| af | Afrikaans |
|
||||
| oc | Occitan |
|
||||
| ka | Georgian |
|
||||
| be | Belarusian |
|
||||
| tg | Tajik |
|
||||
| sd | Sindhi |
|
||||
| gu | Gujarati |
|
||||
| am | Amharic |
|
||||
| yi | Yiddish |
|
||||
| lo | Lao |
|
||||
| uz | Uzbek |
|
||||
| fo | Faroese |
|
||||
| ht | Haitian Creole |
|
||||
| ps | Pashto |
|
||||
| tk | Turkmen |
|
||||
| nn | Nynorsk |
|
||||
| mt | Maltese |
|
||||
| sa | Sanskrit |
|
||||
| lb | Luxembourgish |
|
||||
| my | Myanmar |
|
||||
| bo | Tibetan |
|
||||
| tl | Tagalog |
|
||||
| mg | Malagasy |
|
||||
| as | Assamese |
|
||||
| tt | Tatar |
|
||||
| haw | Hawaiian |
|
||||
| ln | Lingala |
|
||||
| ha | Hausa |
|
||||
| ba | Bashkir |
|
||||
| jw | Javanese |
|
||||
| su | Sundanese |
|
||||
| yue | Cantonese |
|
||||
|
||||
|
||||
# Translation: Supported Languages
|
||||
|
||||
WLK supports translation into **201 languages** from the FLORES-200 dataset through the [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) translation system.
|
||||
|
||||
## How to Specify Languages
|
||||
|
||||
You can specify languages in **three different ways**:
|
||||
|
||||
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
|
||||
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
|
||||
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Command Line
|
||||
```bash
|
||||
# Using language name
|
||||
whisperlivekit-server --target-language "French"
|
||||
|
||||
# Using ISO code
|
||||
whisperlivekit-server --target-language fr
|
||||
|
||||
# Using NLLB code
|
||||
whisperlivekit-server --target-language fra_Latn
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from nllw.translation import get_language_info
|
||||
|
||||
# Get language information by name
|
||||
lang_info = get_language_info("French")
|
||||
print(lang_info)
|
||||
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
|
||||
|
||||
# Get language information by ISO code
|
||||
lang_info = get_language_info("fr")
|
||||
|
||||
# Get language information by NLLB code
|
||||
lang_info = get_language_info("fra_Latn")
|
||||
|
||||
# All three return the same result
|
||||
```
|
||||
|
||||
## Complete Language List
|
||||
|
||||
The following table lists all 201 supported languages with their corresponding codes:
|
||||
|
||||
| Language Name | ISO Code | NLLB Code |
|
||||
|---------------|----------|-----------|
|
||||
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
|
||||
| Acehnese (Latin script) | ace_Latn | ace_Latn |
|
||||
| Mesopotamian Arabic | acm_Arab | acm_Arab |
|
||||
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
|
||||
| Tunisian Arabic | aeb_Arab | aeb_Arab |
|
||||
| Afrikaans | af | afr_Latn |
|
||||
| South Levantine Arabic | ajp_Arab | ajp_Arab |
|
||||
| Akan | ak | aka_Latn |
|
||||
| Tosk Albanian | als | als_Latn |
|
||||
| Amharic | am | amh_Ethi |
|
||||
| North Levantine Arabic | apc_Arab | apc_Arab |
|
||||
| Modern Standard Arabic | ar | arb_Arab |
|
||||
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
|
||||
| Najdi Arabic | ars_Arab | ars_Arab |
|
||||
| Moroccan Arabic | ary_Arab | ary_Arab |
|
||||
| Egyptian Arabic | arz_Arab | arz_Arab |
|
||||
| Assamese | as | asm_Beng |
|
||||
| Asturian | ast | ast_Latn |
|
||||
| Awadhi | awa | awa_Deva |
|
||||
| Central Aymara | ay | ayr_Latn |
|
||||
| South Azerbaijani | azb | azb_Arab |
|
||||
| North Azerbaijani | az | azj_Latn |
|
||||
| Bashkir | ba | bak_Cyrl |
|
||||
| Bambara | bm | bam_Latn |
|
||||
| Balinese | ban | ban_Latn |
|
||||
| Belarusian | be | bel_Cyrl |
|
||||
| Bemba | bem | bem_Latn |
|
||||
| Bengali | bn | ben_Beng |
|
||||
| Bhojpuri | bho | bho_Deva |
|
||||
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
|
||||
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
|
||||
| Standard Tibetan | bo | bod_Tibt |
|
||||
| Bosnian | bs | bos_Latn |
|
||||
| Buginese | bug | bug_Latn |
|
||||
| Bulgarian | bg | bul_Cyrl |
|
||||
| Catalan | ca | cat_Latn |
|
||||
| Cebuano | ceb | ceb_Latn |
|
||||
| Czech | cs | ces_Latn |
|
||||
| Chokwe | cjk | cjk_Latn |
|
||||
| Central Kurdish | ckb | ckb_Arab |
|
||||
| Crimean Tatar | crh | crh_Latn |
|
||||
| Welsh | cy | cym_Latn |
|
||||
| Danish | da | dan_Latn |
|
||||
| German | de | deu_Latn |
|
||||
| Southwestern Dinka | dik | dik_Latn |
|
||||
| Dyula | dyu | dyu_Latn |
|
||||
| Dzongkha | dz | dzo_Tibt |
|
||||
| Greek | el | ell_Grek |
|
||||
| English | en | eng_Latn |
|
||||
| Esperanto | eo | epo_Latn |
|
||||
| Estonian | et | est_Latn |
|
||||
| Basque | eu | eus_Latn |
|
||||
| Ewe | ee | ewe_Latn |
|
||||
| Faroese | fo | fao_Latn |
|
||||
| Fijian | fj | fij_Latn |
|
||||
| Finnish | fi | fin_Latn |
|
||||
| Fon | fon | fon_Latn |
|
||||
| French | fr | fra_Latn |
|
||||
| Friulian | fur-IT | fur_Latn |
|
||||
| Nigerian Fulfulde | fuv | fuv_Latn |
|
||||
| West Central Oromo | om | gaz_Latn |
|
||||
| Scottish Gaelic | gd | gla_Latn |
|
||||
| Irish | ga-IE | gle_Latn |
|
||||
| Galician | gl | glg_Latn |
|
||||
| Guarani | gn | grn_Latn |
|
||||
| Gujarati | gu-IN | guj_Gujr |
|
||||
| Haitian Creole | ht | hat_Latn |
|
||||
| Hausa | ha | hau_Latn |
|
||||
| Hebrew | he | heb_Hebr |
|
||||
| Hindi | hi | hin_Deva |
|
||||
| Chhattisgarhi | hne | hne_Deva |
|
||||
| Croatian | hr | hrv_Latn |
|
||||
| Hungarian | hu | hun_Latn |
|
||||
| Armenian | hy-AM | hye_Armn |
|
||||
| Igbo | ig | ibo_Latn |
|
||||
| Ilocano | ilo | ilo_Latn |
|
||||
| Indonesian | id | ind_Latn |
|
||||
| Icelandic | is | isl_Latn |
|
||||
| Italian | it | ita_Latn |
|
||||
| Javanese | jv | jav_Latn |
|
||||
| Japanese | ja | jpn_Jpan |
|
||||
| Kabyle | kab | kab_Latn |
|
||||
| Jingpho | kac | kac_Latn |
|
||||
| Kamba | kam | kam_Latn |
|
||||
| Kannada | kn | kan_Knda |
|
||||
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
|
||||
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
|
||||
| Georgian | ka | kat_Geor |
|
||||
| Kazakh | kk | kaz_Cyrl |
|
||||
| Kabiyè | kbp | kbp_Latn |
|
||||
| Kabuverdianu | kea | kea_Latn |
|
||||
| Halh Mongolian | mn | khk_Cyrl |
|
||||
| Khmer | km | khm_Khmr |
|
||||
| Kikuyu | ki | kik_Latn |
|
||||
| Kinyarwanda | rw | kin_Latn |
|
||||
| Kyrgyz | ky | kir_Cyrl |
|
||||
| Kimbundu | kmb | kmb_Latn |
|
||||
| Northern Kurdish | kmr | kmr_Latn |
|
||||
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
|
||||
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
|
||||
| Kikongo | kg | kon_Latn |
|
||||
| Korean | ko | kor_Hang |
|
||||
| Lao | lo | lao_Laoo |
|
||||
| Ligurian | lij | lij_Latn |
|
||||
| Limburgish | li | lim_Latn |
|
||||
| Lingala | ln | lin_Latn |
|
||||
| Lithuanian | lt | lit_Latn |
|
||||
| Lombard | lmo | lmo_Latn |
|
||||
| Latgalian | ltg | ltg_Latn |
|
||||
| Luxembourgish | lb | ltz_Latn |
|
||||
| Luba-Kasai | lua | lua_Latn |
|
||||
| Ganda | lg | lug_Latn |
|
||||
| Luo | luo | luo_Latn |
|
||||
| Mizo | lus | lus_Latn |
|
||||
| Standard Latvian | lv | lvs_Latn |
|
||||
| Magahi | mag | mag_Deva |
|
||||
| Maithili | mai | mai_Deva |
|
||||
| Malayalam | ml-IN | mal_Mlym |
|
||||
| Marathi | mr | mar_Deva |
|
||||
| Minangkabau (Arabic script) | min_Arab | min_Arab |
|
||||
| Minangkabau (Latin script) | min_Latn | min_Latn |
|
||||
| Macedonian | mk | mkd_Cyrl |
|
||||
| Maltese | mt | mlt_Latn |
|
||||
| Meitei (Bengali script) | mni | mni_Beng |
|
||||
| Mossi | mos | mos_Latn |
|
||||
| Maori | mi | mri_Latn |
|
||||
| Burmese | my | mya_Mymr |
|
||||
| Dutch | nl | nld_Latn |
|
||||
| Norwegian Nynorsk | nn-NO | nno_Latn |
|
||||
| Norwegian Bokmål | nb | nob_Latn |
|
||||
| Nepali | ne-NP | npi_Deva |
|
||||
| Northern Sotho | nso | nso_Latn |
|
||||
| Nuer | nus | nus_Latn |
|
||||
| Nyanja | ny | nya_Latn |
|
||||
| Occitan | oc | oci_Latn |
|
||||
| Odia | or | ory_Orya |
|
||||
| Pangasinan | pag | pag_Latn |
|
||||
| Eastern Panjabi | pa | pan_Guru |
|
||||
| Papiamento | pap | pap_Latn |
|
||||
| Southern Pashto | pbt | pbt_Arab |
|
||||
| Western Persian | fa | pes_Arab |
|
||||
| Plateau Malagasy | mg | plt_Latn |
|
||||
| Polish | pl | pol_Latn |
|
||||
| Portuguese | pt-PT | por_Latn |
|
||||
| Dari | fa-AF | prs_Arab |
|
||||
| Ayacucho Quechua | qu | quy_Latn |
|
||||
| Romanian | ro | ron_Latn |
|
||||
| Rundi | rn | run_Latn |
|
||||
| Russian | ru | rus_Cyrl |
|
||||
| Sango | sg | sag_Latn |
|
||||
| Sanskrit | sa | san_Deva |
|
||||
| Santali | sat | sat_Olck |
|
||||
| Sicilian | scn | scn_Latn |
|
||||
| Shan | shn | shn_Mymr |
|
||||
| Sinhala | si-LK | sin_Sinh |
|
||||
| Slovak | sk | slk_Latn |
|
||||
| Slovenian | sl | slv_Latn |
|
||||
| Samoan | sm | smo_Latn |
|
||||
| Shona | sn | sna_Latn |
|
||||
| Sindhi | sd | snd_Arab |
|
||||
| Somali | so | som_Latn |
|
||||
| Southern Sotho | st | sot_Latn |
|
||||
| Spanish | es-ES | spa_Latn |
|
||||
| Sardinian | sc | srd_Latn |
|
||||
| Serbian | sr | srp_Cyrl |
|
||||
| Swati | ss | ssw_Latn |
|
||||
| Sundanese | su | sun_Latn |
|
||||
| Swedish | sv-SE | swe_Latn |
|
||||
| Swahili | sw | swh_Latn |
|
||||
| Silesian | szl | szl_Latn |
|
||||
| Tamil | ta | tam_Taml |
|
||||
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
|
||||
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
|
||||
| Tatar | tt-RU | tat_Cyrl |
|
||||
| Telugu | te | tel_Telu |
|
||||
| Tajik | tg | tgk_Cyrl |
|
||||
| Tagalog | tl | tgl_Latn |
|
||||
| Thai | th | tha_Thai |
|
||||
| Tigrinya | ti | tir_Ethi |
|
||||
| Tok Pisin | tpi | tpi_Latn |
|
||||
| Tswana | tn | tsn_Latn |
|
||||
| Tsonga | ts | tso_Latn |
|
||||
| Turkmen | tk | tuk_Latn |
|
||||
| Tumbuka | tum | tum_Latn |
|
||||
| Turkish | tr | tur_Latn |
|
||||
| Twi | tw | twi_Latn |
|
||||
| Central Atlas Tamazight | tzm | tzm_Tfng |
|
||||
| Uyghur | ug | uig_Arab |
|
||||
| Ukrainian | uk | ukr_Cyrl |
|
||||
| Umbundu | umb | umb_Latn |
|
||||
| Urdu | ur | urd_Arab |
|
||||
| Northern Uzbek | uz | uzn_Latn |
|
||||
| Venetian | vec | vec_Latn |
|
||||
| Vietnamese | vi | vie_Latn |
|
||||
| Waray | war | war_Latn |
|
||||
| Wolof | wo | wol_Latn |
|
||||
| Xhosa | xh | xho_Latn |
|
||||
| Eastern Yiddish | yi | ydd_Hebr |
|
||||
| Yoruba | yo | yor_Latn |
|
||||
| Yue Chinese | yue | yue_Hant |
|
||||
| Chinese (Simplified) | zh-CN | zho_Hans |
|
||||
| Chinese (Traditional) | zh-TW | zho_Hant |
|
||||
| Standard Malay | ms | zsm_Latn |
|
||||
| Zulu | zu | zul_Latn |
|
||||
|
||||
## Special Features
|
||||
|
||||
### Multiple Script Support
|
||||
Several languages are available in multiple scripts (e.g., Arabic and Latin):
|
||||
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
|
||||
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
|
||||
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
|
||||
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
|
||||
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
|
||||
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)
|
||||
43
docs/technical_integration.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Technical Integration Guide
|
||||
|
||||
This document introduce how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
|
||||
|
||||
---
|
||||
|
||||
## 1. Runtime Components
|
||||
|
||||
| Layer | File(s) | Purpose |
|
||||
|-------|---------|---------|
|
||||
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
|
||||
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
|
||||
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
|
||||
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
|
||||
|
||||
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
|
||||
|
||||
---
|
||||
|
||||
## 2. Running Without the Bundled Frontend
|
||||
|
||||
1. Start the server/engine however you like:
|
||||
```bash
|
||||
wlk --model small --language en --host 0.0.0.0 --port 9000
|
||||
# or launch your own app that instantiates TranscriptionEngine(...)
|
||||
```
|
||||
2. Build your own client (browser, mobile, desktop) that:
|
||||
- Opens `ws(s)://<host>:<port>/asr`
|
||||
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
|
||||
- Consumes the JSON payload defined in `docs/API.md`.
|
||||
|
||||
---
|
||||
|
||||
## 3. Running Without FastAPI
|
||||
|
||||
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
|
||||
|
||||
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
|
||||
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
|
||||
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
|
||||
|
||||
|
||||
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently. Just ensure `ffmpeg` is available.
|
||||
140
docs/troubleshooting.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Troubleshooting
|
||||
|
||||
|
||||
## GPU drivers & cuDNN visibility
|
||||
|
||||
### Linux error: `Unable to load libcudnn_ops.so* / cudnnCreateTensorDescriptor`
|
||||
> Reported in issue #271 (Arch/CachyOS)
|
||||
|
||||
`faster-whisper` (used for the SimulStreaming encoder) dynamically loads cuDNN.
|
||||
If the runtime cannot find `libcudnn_*`, verify that CUDA and cuDNN match the PyTorch build you installed:
|
||||
|
||||
1. **Install CUDA + cuDNN** (Arch/CachyOS example):
|
||||
```bash
|
||||
sudo pacman -S cuda cudnn
|
||||
sudo ldconfig
|
||||
```
|
||||
2. **Make sure the shared objects are visible**:
|
||||
```bash
|
||||
ls /usr/lib/libcudnn*
|
||||
```
|
||||
3. **Check what CUDA version PyTorch expects** and match that with the driver you installed:
|
||||
```bash
|
||||
python - <<'EOF'
|
||||
import torch
|
||||
print(torch.version.cuda)
|
||||
EOF
|
||||
nvcc --version
|
||||
```
|
||||
4. If you installed CUDA in a non-default location, export `CUDA_HOME` and add `$CUDA_HOME/lib64` to `LD_LIBRARY_PATH`.
|
||||
|
||||
Once the CUDA/cuDNN versions match, `whisperlivekit-server` starts normally.
|
||||
|
||||
### Windows error: `Could not locate cudnn_ops64_9.dll`
|
||||
> Reported in issue #286 (Conda on Windows)
|
||||
|
||||
PyTorch bundles cuDNN DLLs inside your environment (`<env>\Lib\site-packages\torch\lib`).
|
||||
When `ctranslate2` or `faster-whisper` cannot find `cudnn_ops64_9.dll`:
|
||||
|
||||
1. Locate the DLL shipped with PyTorch, e.g.
|
||||
```
|
||||
E:\conda\envs\WhisperLiveKit\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
|
||||
```
|
||||
2. Add that directory to your `PATH` **or** copy the `cudnn_*64_9.dll` files into a directory that is already on `PATH` (such as the environment's `Scripts/` folder).
|
||||
3. Restart the shell before launching `wlk`.
|
||||
|
||||
Installing NVIDIA's standalone cuDNN 9.x and pointing `PATH`/`CUDNN_PATH` to it works as well, but is usually not required.
|
||||
|
||||
---
|
||||
|
||||
## PyTorch / CTranslate2 GPU builds
|
||||
|
||||
### `Torch not compiled with CUDA enabled`
|
||||
> Reported in issue #284
|
||||
|
||||
If `torch.zeros(1).cuda()` raises that assertion it means you installed a CPU-only wheel.
|
||||
Install the GPU-enabled wheels that match your CUDA toolkit:
|
||||
|
||||
```bash
|
||||
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
|
||||
```
|
||||
|
||||
Replace `cu130` with the CUDA version supported by your driver (see [PyTorch install selector](https://pytorch.org/get-started/locally/)).
|
||||
Validate with:
|
||||
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.is_available(), torch.cuda.get_device_name())
|
||||
```
|
||||
|
||||
### `CTranslate2 device count: 0` or `Could not infer dtype of ctranslate2._ext.StorageView`
|
||||
> Follow-up in issue #284
|
||||
|
||||
`ctranslate2` publishes separate CPU and CUDA wheels. The default `pip install ctranslate2` brings the CPU build, which makes WhisperLiveKit fall back to CPU tensors and leads to the dtype error above.
|
||||
|
||||
1. Uninstall the CPU build: `pip uninstall -y ctranslate2`.
|
||||
2. Install the CUDA wheel that matches your toolkit (example for CUDA 13.0):
|
||||
```bash
|
||||
pip install ctranslate2==4.5.0 -f https://opennmt.net/ctranslate2/whl/cu130
|
||||
```
|
||||
(See the [CTranslate2 installation table](https://opennmt.net/CTranslate2/installation.html) for other CUDA versions.)
|
||||
3. Verify:
|
||||
```python
|
||||
import ctranslate2
|
||||
print("CUDA devices:", ctranslate2.get_cuda_device_count())
|
||||
print("CUDA compute types:", ctranslate2.get_supported_compute_types("cuda", 0))
|
||||
```
|
||||
|
||||
**Note for aarch64 systems (e.g., NVIDIA DGX Spark):** Pre-built CUDA wheels may not be available for all CUDA versions on ARM architectures. If the wheel installation fails, you may need to compile CTranslate2 from source with CUDA support enabled.
|
||||
|
||||
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
|
||||
|
||||
---
|
||||
|
||||
## Hopper / Blackwell (`sm_121a`) systems
|
||||
> Reported in issues #276 and #284 (NVIDIA DGX Spark)
|
||||
|
||||
CUDA 12.1a GPUs (e.g., NVIDIA GB10 on DGX Spark) ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual configuration.
|
||||
|
||||
### Error: `ptxas fatal : Value 'sm_121a' is not defined for option 'gpu-name'`
|
||||
|
||||
If you encounter this error after compiling CTranslate2 from source on aarch64 systems, Triton's bundled `ptxas` may not support the `sm_121a` architecture. The solution is to replace Triton's `ptxas` with the system's CUDA `ptxas`:
|
||||
|
||||
```bash
|
||||
# Find your Python environment's Triton directory
|
||||
python -c "import triton; import os; print(os.path.dirname(triton.__file__))"
|
||||
|
||||
# Copy the system ptxas to Triton's backend directory
|
||||
# Replace <triton_path> with the output above
|
||||
cp /usr/local/cuda/bin/ptxas <triton_path>/backends/nvidia/bin/ptxas
|
||||
```
|
||||
|
||||
For example, in a virtual environment:
|
||||
```bash
|
||||
cp /usr/local/cuda/bin/ptxas ~/wlk/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
```
|
||||
|
||||
**Note:** On DGX Spark systems, CUDA is typically already in `PATH` (`/usr/local/cuda/bin`), so explicit `CUDA_HOME` and `PATH` exports may not be necessary. Verify with `which ptxas` before copying.
|
||||
|
||||
### Alternative: Environment variable approach
|
||||
|
||||
If the above doesn't work, you can try setting environment variables (though this may not resolve the `sm_121a` issue on all systems):
|
||||
|
||||
```bash
|
||||
export CUDA_HOME="/usr/local/cuda-13.0"
|
||||
export PATH="$CUDA_HOME/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
|
||||
|
||||
# Tell Triton where the new ptxas lives
|
||||
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
|
||||
|
||||
# Force PyTorch to JIT kernels for all needed architectures
|
||||
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
|
||||
```
|
||||
|
||||
After applying the fix, restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
|
||||
|
||||
---
|
||||
|
||||
Need help with another recurring issue? Open a GitHub discussion or PR and reference this document so we can keep it current.
|
||||
|
||||
141
pyproject.toml
Normal file
@@ -0,0 +1,141 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.19"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "Quentin Fuxa" }]
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.11, <3.14"
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
||||
]
|
||||
dependencies = [
|
||||
"fastapi",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"uvicorn",
|
||||
"websockets",
|
||||
"huggingface-hub>=0.25.0",
|
||||
"faster-whisper>=1.2.0",
|
||||
"torch>=2.0.0",
|
||||
"torchaudio>=2.0.0",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = ["pytest>=7.0", "pytest-asyncio>=0.21"]
|
||||
translation = ["nllw"]
|
||||
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
||||
mlx-whisper = [
|
||||
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||
]
|
||||
voxtral-mlx = [
|
||||
'mlx>=0.11.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||
'mlx-whisper>=0.4.0; sys_platform == "darwin" and platform_machine == "arm64"',
|
||||
"mistral-common[audio]",
|
||||
]
|
||||
voxtral-hf = [
|
||||
"transformers>=5.2.0; python_version >= '3.10'",
|
||||
"mistral-common[audio]",
|
||||
"accelerate>=0.12",
|
||||
]
|
||||
cpu = ["torch>=2.0.0", "torchaudio>=2.0.0"]
|
||||
cu129 = [
|
||||
"torch>=2.0.0",
|
||||
"torchaudio>=2.0.0",
|
||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")',
|
||||
]
|
||||
diarization-sortformer = [
|
||||
"nemo-toolkit[asr]>2.4; python_version >= '3.10' and python_version < '3.13'",
|
||||
]
|
||||
diarization-diart = [
|
||||
"diart",
|
||||
"torch<2.9.0",
|
||||
"torchaudio<2.9.0",
|
||||
"torchvision<0.24.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["rich>=14.3.3"]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "cu129" },
|
||||
],
|
||||
[
|
||||
{ extra = "diarization-diart" },
|
||||
{ extra = "cu129" },
|
||||
],
|
||||
[
|
||||
{ extra = "voxtral-hf" },
|
||||
{ extra = "diarization-sortformer" },
|
||||
],
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
|
||||
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },
|
||||
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||
{ index = "pytorch-cu129", extra = "cu129", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" },
|
||||
]
|
||||
torchvision = [
|
||||
{ index = "pytorch-cpu", extra = "diarization-diart", marker = "platform_system != 'Darwin'" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu129"
|
||||
url = "https://download.pytorch.org/whl/cu129"
|
||||
explicit = true
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
|
||||
[project.scripts]
|
||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||
wlk = "whisperlivekit.basic_server:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = [
|
||||
"whisperlivekit",
|
||||
"whisperlivekit.diarization",
|
||||
"whisperlivekit.simul_whisper",
|
||||
"whisperlivekit.simul_whisper.mlx",
|
||||
"whisperlivekit.whisper",
|
||||
"whisperlivekit.whisper.assets",
|
||||
"whisperlivekit.whisper.normalizers",
|
||||
"whisperlivekit.web",
|
||||
"whisperlivekit.local_agreement",
|
||||
"whisperlivekit.voxtral_mlx",
|
||||
"whisperlivekit.silero_vad_models",
|
||||
]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.whisper.normalizers" = ["*.json"]
|
||||
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
||||
291
run_benchmark.py
Normal file
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive benchmark runner for WhisperLiveKit.
|
||||
|
||||
Tests all available backend+policy combinations across multiple audio files,
|
||||
model sizes, and VAC on/off configurations. Outputs structured JSON that
|
||||
is consumed by the report generator.
|
||||
|
||||
Usage:
|
||||
python run_benchmark.py # full benchmark
|
||||
python run_benchmark.py --quick # subset (tiny models, fewer combos)
|
||||
python run_benchmark.py --json results.json # custom output path
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(level=logging.WARNING, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger("benchmark")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Re-use harness functions
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from test_backend_offline import (
|
||||
AUDIO_TESTS_DIR,
|
||||
SAMPLE_RATE,
|
||||
TestResult,
|
||||
create_engine,
|
||||
discover_audio_files,
|
||||
download_sample_audio,
|
||||
load_audio,
|
||||
run_test,
|
||||
)
|
||||
|
||||
CACHE_DIR = Path(__file__).parent / ".test_cache"
|
||||
|
||||
|
||||
def get_system_info() -> dict:
|
||||
"""Collect system metadata for the report."""
|
||||
info = {
|
||||
"platform": platform.platform(),
|
||||
"machine": platform.machine(),
|
||||
"processor": platform.processor(),
|
||||
"python_version": platform.python_version(),
|
||||
}
|
||||
|
||||
# macOS: get chip info
|
||||
try:
|
||||
chip = subprocess.check_output(
|
||||
["sysctl", "-n", "machdep.cpu.brand_string"], text=True
|
||||
).strip()
|
||||
info["cpu"] = chip
|
||||
except Exception:
|
||||
info["cpu"] = platform.processor()
|
||||
|
||||
# RAM
|
||||
try:
|
||||
mem_bytes = int(
|
||||
subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip()
|
||||
)
|
||||
info["ram_gb"] = round(mem_bytes / (1024**3))
|
||||
except Exception:
|
||||
info["ram_gb"] = None
|
||||
|
||||
# Backend versions
|
||||
versions = {}
|
||||
try:
|
||||
import faster_whisper
|
||||
versions["faster-whisper"] = faster_whisper.__version__
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import mlx_whisper # noqa: F401
|
||||
versions["mlx-whisper"] = "installed"
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import mlx.core as mx
|
||||
versions["mlx"] = mx.__version__
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import transformers
|
||||
versions["transformers"] = transformers.__version__
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import torch
|
||||
versions["torch"] = torch.__version__
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
info["backend_versions"] = versions
|
||||
return info
|
||||
|
||||
|
||||
def detect_combos(quick: bool = False) -> list:
|
||||
"""Build list of (backend, policy, model_size) combos to test."""
|
||||
combos = []
|
||||
|
||||
# Model sizes to test
|
||||
model_sizes = ["tiny", "base", "small"] if not quick else ["tiny", "base"]
|
||||
|
||||
# faster-whisper
|
||||
try:
|
||||
import faster_whisper # noqa: F401
|
||||
for model in model_sizes:
|
||||
combos.append({"backend": "faster-whisper", "policy": "localagreement", "model": model})
|
||||
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "model": model})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# mlx-whisper
|
||||
try:
|
||||
import mlx_whisper # noqa: F401
|
||||
for model in model_sizes:
|
||||
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "model": model})
|
||||
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "model": model})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# voxtral-mlx (single model, single policy)
|
||||
try:
|
||||
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
|
||||
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "model": ""})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# voxtral HF (single model, single policy)
|
||||
try:
|
||||
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
|
||||
combos.append({"backend": "voxtral", "policy": "voxtral", "model": ""})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return combos
|
||||
|
||||
|
||||
def collect_audio_files() -> list:
|
||||
"""Collect all benchmark audio files."""
|
||||
files = []
|
||||
|
||||
# audio_tests/ directory
|
||||
if AUDIO_TESTS_DIR.is_dir():
|
||||
files.extend(discover_audio_files(str(AUDIO_TESTS_DIR)))
|
||||
|
||||
# JFK sample
|
||||
jfk = CACHE_DIR / "jfk.wav"
|
||||
if not jfk.exists():
|
||||
jfk = download_sample_audio()
|
||||
if jfk.exists():
|
||||
files.append(jfk)
|
||||
|
||||
return files
|
||||
|
||||
|
||||
async def run_single_combo(
|
||||
combo: dict, audio_files: list, vac: bool, lan: str, max_duration: float,
|
||||
) -> list:
|
||||
"""Run one backend+policy+model combo across all audio files."""
|
||||
backend = combo["backend"]
|
||||
policy = combo["policy"]
|
||||
model = combo["model"]
|
||||
|
||||
results = []
|
||||
try:
|
||||
engine = create_engine(
|
||||
backend=backend,
|
||||
model_size=model,
|
||||
lan=lan,
|
||||
vac=vac,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
# Quiet noisy loggers
|
||||
for mod in (
|
||||
"whisperlivekit.audio_processor",
|
||||
"whisperlivekit.simul_whisper",
|
||||
"whisperlivekit.tokens_alignment",
|
||||
"whisperlivekit.simul_whisper.align_att_base",
|
||||
"whisperlivekit.simul_whisper.simul_whisper",
|
||||
):
|
||||
logging.getLogger(mod).setLevel(logging.WARNING)
|
||||
|
||||
for audio_path in audio_files:
|
||||
duration = len(load_audio(str(audio_path))) / SAMPLE_RATE
|
||||
if duration > max_duration:
|
||||
logger.info(f" Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s)")
|
||||
continue
|
||||
|
||||
file_lan = lan
|
||||
if "french" in audio_path.name.lower() and lan == "en":
|
||||
file_lan = "fr"
|
||||
|
||||
audio = load_audio(str(audio_path))
|
||||
result = await run_test(
|
||||
engine, audio, chunk_ms=100, realtime=False,
|
||||
audio_file=audio_path.name, backend=backend,
|
||||
policy=policy, lan=file_lan,
|
||||
)
|
||||
# Tag with extra metadata
|
||||
result_dict = asdict(result)
|
||||
result_dict["model_size"] = model
|
||||
result_dict["vac"] = vac
|
||||
results.append(result_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" FAILED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def run_full_benchmark(combos, audio_files, max_duration=60.0):
|
||||
"""Run all combos with VAC on and off."""
|
||||
all_results = []
|
||||
total = len(combos) * 2 # x2 for VAC on/off
|
||||
idx = 0
|
||||
|
||||
for combo in combos:
|
||||
for vac in [True, False]:
|
||||
idx += 1
|
||||
vac_str = "VAC=on" if vac else "VAC=off"
|
||||
desc = f"{combo['backend']} / {combo['policy']}"
|
||||
if combo["model"]:
|
||||
desc += f" / {combo['model']}"
|
||||
desc += f" / {vac_str}"
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
print(f"[{idx}/{total}] {desc}")
|
||||
print(f"{'='*70}")
|
||||
|
||||
results = await run_single_combo(
|
||||
combo, audio_files, vac=vac, lan="en", max_duration=max_duration,
|
||||
)
|
||||
all_results.extend(results)
|
||||
|
||||
# Free memory between combos
|
||||
gc.collect()
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run comprehensive WhisperLiveKit benchmark")
|
||||
parser.add_argument("--quick", action="store_true", help="Quick mode: fewer models and combos")
|
||||
parser.add_argument("--json", default="benchmark_results.json", dest="json_output", help="Output JSON path")
|
||||
parser.add_argument("--max-duration", type=float, default=60.0, help="Max audio duration in seconds")
|
||||
args = parser.parse_args()
|
||||
|
||||
system_info = get_system_info()
|
||||
combos = detect_combos(quick=args.quick)
|
||||
audio_files = collect_audio_files()
|
||||
|
||||
print(f"System: {system_info.get('cpu', 'unknown')}, {system_info.get('ram_gb', '?')}GB RAM")
|
||||
print(f"Backends: {list(system_info['backend_versions'].keys())}")
|
||||
print(f"Combos to test: {len(combos)} x 2 (VAC on/off) = {len(combos)*2}")
|
||||
print(f"Audio files: {[f.name for f in audio_files]}")
|
||||
print()
|
||||
|
||||
t0 = time.time()
|
||||
all_results = asyncio.run(
|
||||
run_full_benchmark(combos, audio_files, max_duration=args.max_duration)
|
||||
)
|
||||
total_time = time.time() - t0
|
||||
|
||||
output = {
|
||||
"system_info": system_info,
|
||||
"benchmark_date": time.strftime("%Y-%m-%d %H:%M"),
|
||||
"total_benchmark_time_s": round(total_time, 1),
|
||||
"n_combos": len(combos) * 2,
|
||||
"n_audio_files": len(audio_files),
|
||||
"results": all_results,
|
||||
}
|
||||
|
||||
Path(args.json_output).write_text(json.dumps(output, indent=2, ensure_ascii=False))
|
||||
print(f"\nBenchmark complete in {total_time:.0f}s. Results: {args.json_output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
scripts/alignment_heads.png
Normal file
|
After Width: | Height: | Size: 276 KiB |
153
scripts/convert_hf_whisper.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
|
||||
|
||||
Optionally shrink the supported audio chunk length (in seconds) by trimming the
|
||||
encoder positional embeddings and updating the stored model dimensions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||
from whisperlivekit.whisper.model import ModelDimensions
|
||||
from whisperlivekit.whisper.utils import exact_div
|
||||
|
||||
|
||||
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
||||
safetensor_path = repo_path / "model.safetensors"
|
||||
bin_path = repo_path / "pytorch_model.bin"
|
||||
|
||||
if safetensor_path.is_file():
|
||||
try:
|
||||
from safetensors.torch import load_file # type: ignore
|
||||
except Exception as exc: # pragma: no cover - import guard
|
||||
raise RuntimeError(
|
||||
"Install safetensors to load model.safetensors "
|
||||
"(pip install safetensors)"
|
||||
) from exc
|
||||
return load_file(str(safetensor_path))
|
||||
|
||||
if bin_path.is_file():
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
|
||||
)
|
||||
|
||||
|
||||
def _load_config(repo_path: Path) -> Dict:
|
||||
config_path = repo_path / "config.json"
|
||||
if not config_path.is_file():
|
||||
raise FileNotFoundError(
|
||||
f"Hugging Face checkpoint at {repo_path} is missing config.json"
|
||||
)
|
||||
with open(config_path, "r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
|
||||
n_samples = int(round(chunk_length * SAMPLE_RATE))
|
||||
expected_samples = chunk_length * SAMPLE_RATE
|
||||
if abs(n_samples - expected_samples) > 1e-6:
|
||||
raise ValueError(
|
||||
"chunk_length must align with sample rate so that "
|
||||
"chunk_length * SAMPLE_RATE is an integer"
|
||||
)
|
||||
n_frames = exact_div(n_samples, HOP_LENGTH)
|
||||
n_audio_ctx = exact_div(n_frames, 2)
|
||||
return n_frames, n_audio_ctx
|
||||
|
||||
|
||||
def _build_dims(config: Dict, chunk_length: float) -> Dict:
|
||||
base_dims = ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
).__dict__.copy()
|
||||
|
||||
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
|
||||
base_dims["n_audio_ctx"] = n_audio_ctx
|
||||
base_dims["chunk_length"] = chunk_length
|
||||
return base_dims
|
||||
|
||||
|
||||
def _trim_positional_embedding(
|
||||
state_dict: Dict[str, torch.Tensor], target_ctx: int
|
||||
) -> None:
|
||||
key = "encoder.positional_embedding"
|
||||
if key not in state_dict:
|
||||
raise KeyError(f"{key} missing from converted state dict")
|
||||
|
||||
tensor = state_dict[key]
|
||||
if tensor.shape[0] < target_ctx:
|
||||
raise ValueError(
|
||||
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
|
||||
)
|
||||
if tensor.shape[0] == target_ctx:
|
||||
return
|
||||
state_dict[key] = tensor[:target_ctx].contiguous()
|
||||
|
||||
|
||||
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
|
||||
state_dict = _load_state_dict(hf_path)
|
||||
converted = _convert_hf_state_dict(state_dict)
|
||||
|
||||
config = _load_config(hf_path)
|
||||
dims = _build_dims(config, chunk_length)
|
||||
|
||||
_trim_positional_embedding(converted, dims["n_audio_ctx"])
|
||||
|
||||
package = {"dims": dims, "model_state_dict": converted}
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(package, output_path)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"hf_path",
|
||||
type=str,
|
||||
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="converted-whisper.pt",
|
||||
help="Destination path for the .pt file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Audio chunk length in seconds to support (default: 30)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
|
||||
output_path = Path(os.path.expanduser(args.output)).resolve()
|
||||
|
||||
convert_checkpoint(hf_path, output_path, args.chunk_length)
|
||||
print(f"Saved converted checkpoint to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
294
scripts/determine_alignment_heads.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Determine alignment heads for a variants, such as distilled model"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import gzip
|
||||
import io
|
||||
import math
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from datasets import Audio as DatasetAudio
|
||||
from datasets import load_dataset
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
sys.path.insert(0, str(WHISPER_ROOT))
|
||||
|
||||
from whisper import load_model
|
||||
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||
|
||||
|
||||
def load_dataset_clips(name, config, split, limit):
|
||||
ds = load_dataset(name, config, split=split)
|
||||
ds = ds.cast_column("audio", DatasetAudio(decode=False))
|
||||
clips = []
|
||||
for idx, row in enumerate(ds):
|
||||
if limit is not None and idx >= limit:
|
||||
break
|
||||
audio_field = row["audio"]
|
||||
transcript = row["text"]
|
||||
|
||||
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
|
||||
if waveform_np.ndim > 1:
|
||||
waveform_np = waveform_np.mean(axis=1)
|
||||
waveform = waveform_np
|
||||
transcript = str(transcript)
|
||||
|
||||
clips.append((waveform, transcript))
|
||||
return clips
|
||||
|
||||
|
||||
def load_clips(args):
|
||||
return load_dataset_clips(
|
||||
args.dataset,
|
||||
args.dataset_config,
|
||||
args.dataset_split,
|
||||
args.dataset_num_samples,
|
||||
)
|
||||
|
||||
|
||||
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
|
||||
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
|
||||
return waveform
|
||||
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="pytorch_model.bin",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Torch device to run on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="librispeech_asr"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-config",
|
||||
type=str,
|
||||
default="clean"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
default="validation[:1%]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-num-samples",
|
||||
type=int,
|
||||
default=16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="Z score threshold for a head to be selected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--votes",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="percentage of clips that must vote for a head",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="alignment_heads.b85",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--visualize-top-k",
|
||||
type=int,
|
||||
default=32,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def collect_heads(
|
||||
model,
|
||||
tokenizer,
|
||||
clips: Sequence[Tuple[AudioInput, str]],
|
||||
threshold: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = model.device
|
||||
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
|
||||
strengths = torch.zeros_like(votes)
|
||||
|
||||
for audio_source, transcript in clips:
|
||||
waveform = pad_or_trim(_waveform_from_source(audio_source))
|
||||
mel = log_mel_spectrogram(waveform, device=device)
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
qks = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
for layer_idx, tensor in enumerate(qks):
|
||||
if tensor is None:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1)
|
||||
peak = tensor.max(dim=-1).values # [heads, tokens]
|
||||
strengths[layer_idx] += peak.mean(dim=-1)
|
||||
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
|
||||
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
|
||||
)
|
||||
mask = (zscore > 3).any(dim=-1)
|
||||
votes[layer_idx] += mask.float()
|
||||
|
||||
votes /= len(clips)
|
||||
strengths /= len(clips)
|
||||
return votes, strengths
|
||||
|
||||
|
||||
def _select_heads_for_visualization(selection, strengths, top_k):
|
||||
selected = torch.nonzero(selection, as_tuple=False)
|
||||
if selected.numel() == 0:
|
||||
return []
|
||||
|
||||
entries = [
|
||||
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
|
||||
for layer, head in selected
|
||||
]
|
||||
entries.sort(key=lambda item: item[2], reverse=True)
|
||||
return entries[:top_k]
|
||||
|
||||
def _extract_heatmaps(
|
||||
model,
|
||||
tokenizer,
|
||||
clip: Tuple[AudioInput, str],
|
||||
heads: Sequence[Tuple[int, int, float]],
|
||||
) -> dict:
|
||||
if not heads:
|
||||
return {}
|
||||
|
||||
target_map = {}
|
||||
for layer, head, _ in heads:
|
||||
target_map.setdefault(layer, set()).add(head)
|
||||
|
||||
waveform = pad_or_trim(_waveform_from_source(clip[0]))
|
||||
mel = log_mel_spectrogram(waveform, device=model.device)
|
||||
transcript = clip[1]
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
heatmaps = {}
|
||||
for layer_idx, tensor in enumerate(QKs):
|
||||
if tensor is None or layer_idx not in target_map:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1).cpu()
|
||||
for head_idx in target_map[layer_idx]:
|
||||
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
def _plot_heatmaps(
|
||||
heads, heatmaps, output_path):
|
||||
cols = min(3, len(heads))
|
||||
rows = math.ceil(len(heads) / cols)
|
||||
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
|
||||
|
||||
for idx, (layer, head, score) in enumerate(heads):
|
||||
ax = axes[idx // cols][idx % cols]
|
||||
mat = heatmaps.get((layer, head))
|
||||
if mat is None:
|
||||
ax.axis("off")
|
||||
continue
|
||||
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
|
||||
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
|
||||
ax.set_xlabel("time")
|
||||
ax.set_ylabel("tokens")
|
||||
|
||||
for j in range(len(heads), rows * cols):
|
||||
axes[j // cols][j % cols].axis("off")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def _dump_mask(mask: torch.Tensor, output_path: str):
|
||||
payload = mask.numpy().astype(np.bool_)
|
||||
blob = base64.b85encode(gzip.compress(payload.tobytes()))
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(blob)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
model = load_model(args.model, device=args.device)
|
||||
model.eval()
|
||||
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
|
||||
clips = load_clips(args)
|
||||
|
||||
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
|
||||
# selection = votes > 0.5
|
||||
selection = strengths > 0.05
|
||||
_dump_mask(selection.cpu(), args.output)
|
||||
|
||||
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
|
||||
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
|
||||
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
580
scripts/python_support_matrix.py
Normal file
@@ -0,0 +1,580 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Offline Python support matrix runner for WhisperLiveKit."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
HAS_RICH = True
|
||||
except Exception:
|
||||
HAS_RICH = False
|
||||
|
||||
SAMPLE_URL = (
|
||||
"https://github.com/pyannote/pyannote-audio/raw/develop/tutorials/assets/sample.wav"
|
||||
)
|
||||
SAMPLE_PATH = Path("audio_tests/support-matrix-sample.wav")
|
||||
DEFAULT_LOGS_DIR = Path("outputs/python-matrix/logs")
|
||||
PYTHON_VERSIONS = ("3.11", "3.12", "3.13")
|
||||
CONSOLE = Console() if HAS_RICH else None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MatrixRow:
|
||||
row_id: str
|
||||
extras: tuple[str, ...]
|
||||
backend: str
|
||||
policy: str
|
||||
diarization_backend: str
|
||||
requires_gpu: bool = False
|
||||
|
||||
|
||||
CASES = (
|
||||
MatrixRow(
|
||||
row_id="fw-diart-cpu",
|
||||
extras=("test", "cpu", "diarization-diart"),
|
||||
backend="faster-whisper",
|
||||
policy="simulstreaming",
|
||||
diarization_backend="diart",
|
||||
),
|
||||
MatrixRow(
|
||||
row_id="fw-sortformer-cpu",
|
||||
extras=("test", "cpu", "diarization-sortformer"),
|
||||
backend="faster-whisper",
|
||||
policy="simulstreaming",
|
||||
diarization_backend="sortformer",
|
||||
),
|
||||
MatrixRow(
|
||||
row_id="fw-sortformer-gpu",
|
||||
extras=("test", "cu129", "diarization-sortformer"),
|
||||
backend="faster-whisper",
|
||||
policy="simulstreaming",
|
||||
diarization_backend="sortformer",
|
||||
requires_gpu=True,
|
||||
),
|
||||
MatrixRow(
|
||||
row_id="voxtral-diart-cpu",
|
||||
extras=("test", "cpu", "voxtral-hf", "diarization-diart"),
|
||||
backend="voxtral",
|
||||
policy="voxtral",
|
||||
diarization_backend="diart",
|
||||
),
|
||||
)
|
||||
|
||||
EXPECTED_FAILURE_CASES = {
|
||||
("3.11", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
|
||||
("3.12", "voxtral-diart-cpu"): "known_unstable_voxtral_diart_cpu",
|
||||
}
|
||||
UNSUPPORTED_CASES = {
|
||||
("3.13", "fw-sortformer-cpu"): "unsupported_py313_sortformer_protobuf",
|
||||
("3.13", "fw-sortformer-gpu"): "unsupported_py313_sortformer_protobuf",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CaseResult:
|
||||
python_version: str
|
||||
row_id: str
|
||||
status: Literal["PASS", "FAIL", "N/A"]
|
||||
reason: str
|
||||
duration_sec: float
|
||||
hint: str = ""
|
||||
log_path: str = ""
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Minimal WhisperLiveKit offline support matrix"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout-sec",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Per-case timeout in seconds (default: 300)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
default=str(DEFAULT_LOGS_DIR),
|
||||
help="Directory where per-case logs are written (default: outputs/python-matrix/logs)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def safe_slug(text: str) -> str:
|
||||
return text.replace("=", "-").replace("|", "__").replace("/", "-").replace(" ", "-")
|
||||
|
||||
|
||||
def status_style(status: str) -> str:
|
||||
if status == "PASS":
|
||||
return "green"
|
||||
if status == "FAIL":
|
||||
return "bold red"
|
||||
if status == "N/A":
|
||||
return "yellow"
|
||||
return "white"
|
||||
|
||||
|
||||
def print_line(message: str, style: str | None = None) -> None:
|
||||
if CONSOLE is None:
|
||||
print(message)
|
||||
return
|
||||
if style:
|
||||
CONSOLE.print(message, style=style, highlight=False)
|
||||
else:
|
||||
CONSOLE.print(message, highlight=False)
|
||||
|
||||
|
||||
def tail_text(text: str | None, max_chars: int = 220) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
normalized = " ".join(text.split())
|
||||
if len(normalized) <= max_chars:
|
||||
return normalized
|
||||
return normalized[-max_chars:]
|
||||
|
||||
|
||||
def run_command(
|
||||
cmd: list[str],
|
||||
cwd: Path,
|
||||
env: dict[str, str],
|
||||
timeout: int | None = None,
|
||||
log_path: Path | None = None,
|
||||
log_section: str | None = None,
|
||||
) -> subprocess.CompletedProcess[str]:
|
||||
def _append_log(
|
||||
*,
|
||||
command: list[str],
|
||||
section: str,
|
||||
returncode: int | None,
|
||||
stdout: str | None,
|
||||
stderr: str | None,
|
||||
timed_out: bool = False,
|
||||
) -> None:
|
||||
if log_path is None:
|
||||
return
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with log_path.open("a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== {section} ===\n")
|
||||
f.write(f"$ {shlex.join(command)}\n")
|
||||
if timed_out:
|
||||
f.write("status: timeout\n")
|
||||
else:
|
||||
f.write(f"status: exit_code={returncode}\n")
|
||||
if stdout:
|
||||
f.write("--- stdout ---\n")
|
||||
f.write(stdout)
|
||||
if not stdout.endswith("\n"):
|
||||
f.write("\n")
|
||||
if stderr:
|
||||
f.write("--- stderr ---\n")
|
||||
f.write(stderr)
|
||||
if not stderr.endswith("\n"):
|
||||
f.write("\n")
|
||||
|
||||
section = log_section or "command"
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
cwd=str(cwd),
|
||||
env=env,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
timeout=timeout,
|
||||
)
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
_append_log(
|
||||
command=cmd,
|
||||
section=section,
|
||||
returncode=None,
|
||||
stdout=exc.stdout if isinstance(exc.stdout, str) else None,
|
||||
stderr=exc.stderr if isinstance(exc.stderr, str) else None,
|
||||
timed_out=True,
|
||||
)
|
||||
raise
|
||||
|
||||
_append_log(
|
||||
command=cmd,
|
||||
section=section,
|
||||
returncode=proc.returncode,
|
||||
stdout=proc.stdout,
|
||||
stderr=proc.stderr,
|
||||
)
|
||||
return proc
|
||||
|
||||
|
||||
def detect_gpu_available() -> bool:
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
["nvidia-smi", "-L"],
|
||||
text=True,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
timeout=10,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
return False
|
||||
return proc.returncode == 0
|
||||
|
||||
|
||||
def download_sample(repo_root: Path) -> Path:
|
||||
target = repo_root / SAMPLE_PATH
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
cmd = [
|
||||
"curl",
|
||||
"--fail",
|
||||
"--location",
|
||||
"--silent",
|
||||
"--show-error",
|
||||
SAMPLE_URL,
|
||||
"--output",
|
||||
str(target),
|
||||
]
|
||||
proc = run_command(cmd, cwd=repo_root, env=os.environ.copy())
|
||||
if proc.returncode != 0:
|
||||
hint = tail_text(proc.stderr or proc.stdout)
|
||||
raise RuntimeError(f"sample_download_failed: {hint}")
|
||||
return target
|
||||
|
||||
|
||||
def sync_case_environment(
|
||||
repo_root: Path,
|
||||
python_version: str,
|
||||
row: MatrixRow,
|
||||
env_dir: Path,
|
||||
log_path: Path,
|
||||
) -> tuple[bool, str]:
|
||||
cmd = ["uv", "sync", "--python", python_version, "--no-dev"]
|
||||
for extra in row.extras:
|
||||
cmd.extend(["--extra", extra])
|
||||
env = os.environ.copy()
|
||||
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
|
||||
proc = run_command(
|
||||
cmd,
|
||||
cwd=repo_root,
|
||||
env=env,
|
||||
log_path=log_path,
|
||||
log_section="sync",
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
return False, tail_text(proc.stderr or proc.stdout)
|
||||
return True, ""
|
||||
|
||||
|
||||
def apply_expected_failure_policy(result: CaseResult) -> CaseResult:
|
||||
expected_reason = EXPECTED_FAILURE_CASES.get((result.python_version, result.row_id))
|
||||
if result.status != "FAIL" or not expected_reason:
|
||||
return result
|
||||
override_hint = result.hint
|
||||
if result.reason:
|
||||
override_hint = (
|
||||
f"expected_failure_override original_reason={result.reason}; {override_hint}"
|
||||
if override_hint
|
||||
else f"expected_failure_override original_reason={result.reason}"
|
||||
)
|
||||
return CaseResult(
|
||||
python_version=result.python_version,
|
||||
row_id=result.row_id,
|
||||
status="N/A",
|
||||
reason=expected_reason,
|
||||
duration_sec=result.duration_sec,
|
||||
hint=override_hint,
|
||||
log_path=result.log_path,
|
||||
)
|
||||
|
||||
|
||||
def build_offline_command(
|
||||
python_version: str,
|
||||
row: MatrixRow,
|
||||
sample_audio: Path,
|
||||
timeout_sec: int,
|
||||
) -> tuple[list[str], int | None]:
|
||||
base_cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--python",
|
||||
python_version,
|
||||
"--no-sync",
|
||||
"python",
|
||||
"test_backend_offline.py",
|
||||
"--backend",
|
||||
row.backend,
|
||||
"--policy",
|
||||
row.policy,
|
||||
"--audio",
|
||||
str(sample_audio),
|
||||
"--model",
|
||||
"tiny",
|
||||
"--diarization",
|
||||
"--diarization-backend",
|
||||
row.diarization_backend,
|
||||
"--lan",
|
||||
"en",
|
||||
"--no-realtime",
|
||||
]
|
||||
if shutil.which("timeout"):
|
||||
return ["timeout", str(timeout_sec), *base_cmd], None
|
||||
return base_cmd, timeout_sec
|
||||
|
||||
|
||||
def run_case(
|
||||
repo_root: Path,
|
||||
python_version: str,
|
||||
row: MatrixRow,
|
||||
sample_audio: Path,
|
||||
timeout_sec: int,
|
||||
gpu_available: bool,
|
||||
logs_dir: Path,
|
||||
) -> CaseResult:
|
||||
start = time.monotonic()
|
||||
case_slug = safe_slug(f"py{python_version}-{row.row_id}")
|
||||
log_path = logs_dir / f"run-{case_slug}.log"
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
log_path.write_text("", encoding="utf-8")
|
||||
|
||||
unsupported_reason = UNSUPPORTED_CASES.get((python_version, row.row_id))
|
||||
if unsupported_reason:
|
||||
log_path.write_text(
|
||||
f"[matrix] precheck_short_circuit status=N/A reason={unsupported_reason}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return CaseResult(
|
||||
python_version=python_version,
|
||||
row_id=row.row_id,
|
||||
status="N/A",
|
||||
reason=unsupported_reason,
|
||||
duration_sec=0.0,
|
||||
hint="unsupported_case_precheck",
|
||||
log_path=str(log_path),
|
||||
)
|
||||
|
||||
if row.requires_gpu and not gpu_available:
|
||||
return CaseResult(
|
||||
python_version=python_version,
|
||||
row_id=row.row_id,
|
||||
status="N/A",
|
||||
reason="gpu_unavailable",
|
||||
duration_sec=0.0,
|
||||
hint="nvidia-smi unavailable or failed",
|
||||
log_path=str(log_path),
|
||||
)
|
||||
|
||||
env_dir = repo_root / ".matrix-envs" / safe_slug(f"py{python_version}-{row.row_id}")
|
||||
sync_ok, sync_hint = sync_case_environment(
|
||||
repo_root,
|
||||
python_version,
|
||||
row,
|
||||
env_dir,
|
||||
log_path=log_path,
|
||||
)
|
||||
if not sync_ok:
|
||||
return CaseResult(
|
||||
python_version=python_version,
|
||||
row_id=row.row_id,
|
||||
status="FAIL",
|
||||
reason="dependency_sync_failed",
|
||||
duration_sec=round(time.monotonic() - start, 3),
|
||||
hint=sync_hint,
|
||||
log_path=str(log_path),
|
||||
)
|
||||
|
||||
cmd, process_timeout = build_offline_command(
|
||||
python_version, row, sample_audio, timeout_sec
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["UV_PROJECT_ENVIRONMENT"] = str(env_dir)
|
||||
if row.requires_gpu:
|
||||
env.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
else:
|
||||
env["CUDA_VISIBLE_DEVICES"] = ""
|
||||
try:
|
||||
proc = run_command(
|
||||
cmd,
|
||||
cwd=repo_root,
|
||||
env=env,
|
||||
timeout=process_timeout,
|
||||
log_path=log_path,
|
||||
log_section="offline",
|
||||
)
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
return CaseResult(
|
||||
python_version=python_version,
|
||||
row_id=row.row_id,
|
||||
status="FAIL",
|
||||
reason="offline_timeout",
|
||||
duration_sec=round(time.monotonic() - start, 3),
|
||||
hint=tail_text((exc.stderr or "") if isinstance(exc.stderr, str) else ""),
|
||||
log_path=str(log_path),
|
||||
)
|
||||
|
||||
hint = tail_text(proc.stderr or proc.stdout)
|
||||
if proc.returncode == 0:
|
||||
return CaseResult(
|
||||
python_version=python_version,
|
||||
row_id=row.row_id,
|
||||
status="PASS",
|
||||
reason="ok",
|
||||
duration_sec=round(time.monotonic() - start, 3),
|
||||
hint=hint,
|
||||
log_path=str(log_path),
|
||||
)
|
||||
|
||||
reason = "offline_timeout" if proc.returncode == 124 else "offline_run_failed"
|
||||
return CaseResult(
|
||||
python_version=python_version,
|
||||
row_id=row.row_id,
|
||||
status="FAIL",
|
||||
reason=reason,
|
||||
duration_sec=round(time.monotonic() - start, 3),
|
||||
hint=hint,
|
||||
log_path=str(log_path),
|
||||
)
|
||||
|
||||
|
||||
def print_summary(results: list[CaseResult]) -> None:
|
||||
pass_count = sum(1 for row in results if row.status == "PASS")
|
||||
fail_count = sum(1 for row in results if row.status == "FAIL")
|
||||
na_count = sum(1 for row in results if row.status == "N/A")
|
||||
if CONSOLE is None:
|
||||
print("\n[matrix] results")
|
||||
print("python | row | status | reason | duration_s")
|
||||
print("---|---|---|---|---")
|
||||
for result in results:
|
||||
print(
|
||||
f"{result.python_version} | {result.row_id} | {result.status} | "
|
||||
f"{result.reason} | {result.duration_sec:.3f}"
|
||||
)
|
||||
print(
|
||||
f"\n[matrix] summary pass={pass_count} fail={fail_count} "
|
||||
f"na={na_count} total={len(results)}"
|
||||
)
|
||||
else:
|
||||
table = Table(title="Support Matrix Results")
|
||||
table.add_column("Python", style="cyan", no_wrap=True)
|
||||
table.add_column("Row", style="white")
|
||||
table.add_column("Status", no_wrap=True)
|
||||
table.add_column("Reason")
|
||||
table.add_column("Duration (s)", justify="right", no_wrap=True)
|
||||
for result in results:
|
||||
table.add_row(
|
||||
result.python_version,
|
||||
result.row_id,
|
||||
f"[{status_style(result.status)}]{result.status}[/{status_style(result.status)}]",
|
||||
result.reason,
|
||||
f"{result.duration_sec:.3f}",
|
||||
)
|
||||
CONSOLE.print()
|
||||
CONSOLE.print(table)
|
||||
CONSOLE.print(
|
||||
f"[bold]Summary[/bold] "
|
||||
f"pass=[green]{pass_count}[/green] "
|
||||
f"fail=[bold red]{fail_count}[/bold red] "
|
||||
f"na=[yellow]{na_count}[/yellow] "
|
||||
f"total={len(results)}"
|
||||
)
|
||||
|
||||
diagnostics = [row for row in results if row.status in {"FAIL", "N/A"} and row.hint]
|
||||
if diagnostics:
|
||||
if CONSOLE is None:
|
||||
print("\n[matrix] diagnostics (failed/n-a cases)")
|
||||
for row in diagnostics:
|
||||
print(
|
||||
f"- py={row.python_version} row={row.row_id} "
|
||||
f"status={row.status} reason={row.reason}"
|
||||
)
|
||||
print(f" hint: {row.hint}")
|
||||
if row.log_path:
|
||||
print(f" log: {row.log_path}")
|
||||
else:
|
||||
diagnostics_table = Table(title="Diagnostics (FAIL / N/A)")
|
||||
diagnostics_table.add_column("Case", style="cyan")
|
||||
diagnostics_table.add_column("Status", no_wrap=True)
|
||||
diagnostics_table.add_column("Reason")
|
||||
diagnostics_table.add_column("Hint")
|
||||
diagnostics_table.add_column("Log")
|
||||
for row in diagnostics:
|
||||
diagnostics_table.add_row(
|
||||
f"py={row.python_version} {row.row_id}",
|
||||
f"[{status_style(row.status)}]{row.status}[/{status_style(row.status)}]",
|
||||
row.reason,
|
||||
row.hint,
|
||||
row.log_path,
|
||||
)
|
||||
CONSOLE.print()
|
||||
CONSOLE.print(diagnostics_table)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
if args.timeout_sec <= 0:
|
||||
print("[matrix] error: --timeout-sec must be > 0", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
logs_dir = (repo_root / args.logs_dir).resolve()
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
print_line(f"[matrix] repo_root={repo_root}", style="cyan")
|
||||
print_line(f"[matrix] timeout_sec={args.timeout_sec}", style="cyan")
|
||||
print_line(f"[matrix] logs_dir={logs_dir}", style="cyan")
|
||||
|
||||
try:
|
||||
sample_audio = download_sample(repo_root)
|
||||
except Exception as exc: # pragma: no cover - straightforward failure path
|
||||
if CONSOLE is None:
|
||||
print(f"[matrix] sample_download_failed: {exc}", file=sys.stderr)
|
||||
else:
|
||||
CONSOLE.print(
|
||||
f"[matrix] sample_download_failed: {exc}",
|
||||
style="bold red",
|
||||
highlight=False,
|
||||
)
|
||||
return 1
|
||||
print_line(f"[matrix] sample_audio={sample_audio}", style="cyan")
|
||||
|
||||
gpu_available = detect_gpu_available()
|
||||
print_line(f"[matrix] gpu_available={gpu_available}", style="cyan")
|
||||
|
||||
results: list[CaseResult] = []
|
||||
for python_version in PYTHON_VERSIONS:
|
||||
for row in CASES:
|
||||
print_line(
|
||||
f"\n[matrix] running py={python_version} row={row.row_id}", style="blue"
|
||||
)
|
||||
result = run_case(
|
||||
repo_root=repo_root,
|
||||
python_version=python_version,
|
||||
row=row,
|
||||
sample_audio=sample_audio,
|
||||
timeout_sec=args.timeout_sec,
|
||||
gpu_available=gpu_available,
|
||||
logs_dir=logs_dir,
|
||||
)
|
||||
result = apply_expected_failure_policy(result)
|
||||
results.append(result)
|
||||
print_line(
|
||||
f"[matrix] {result.status} py={result.python_version} "
|
||||
f"row={result.row_id} reason={result.reason} duration={result.duration_sec:.3f}s",
|
||||
style=status_style(result.status),
|
||||
)
|
||||
if result.log_path:
|
||||
print_line(f"[matrix] log={result.log_path}", style="dim")
|
||||
|
||||
print_summary(results)
|
||||
fail_count = sum(1 for row in results if row.status == "FAIL")
|
||||
return 1 if fail_count else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
40
scripts/sync_extension.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Copy core files from web directory to Chrome extension directory."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def sync_extension_files():
|
||||
|
||||
web_dir = Path("whisperlivekit/web")
|
||||
extension_dir = Path("chrome-extension")
|
||||
|
||||
files_to_sync = [
|
||||
"live_transcription.html", "live_transcription.js", "live_transcription.css"
|
||||
]
|
||||
|
||||
svg_files = [
|
||||
"system_mode.svg",
|
||||
"light_mode.svg",
|
||||
"dark_mode.svg",
|
||||
"settings.svg"
|
||||
]
|
||||
|
||||
for file in files_to_sync:
|
||||
src_path = web_dir / file
|
||||
dest_path = extension_dir / file
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
|
||||
for svg_file in svg_files:
|
||||
src_path = web_dir / "src" / svg_file
|
||||
dest_path = extension_dir / "web" / "src" / svg_file
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
sync_extension_files()
|
||||
55
setup.py
@@ -1,55 +0,0 @@
|
||||
from setuptools import setup, find_packages
|
||||
setup(
|
||||
name="whisperlivekit",
|
||||
version="0.2.1",
|
||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
author="Quentin Fuxa",
|
||||
url="https://github.com/QuentinFuxa/WhisperLiveKit",
|
||||
packages=find_packages(),
|
||||
install_requires=[
|
||||
"fastapi",
|
||||
"librosa",
|
||||
"soundfile",
|
||||
"faster-whisper",
|
||||
"uvicorn",
|
||||
"websockets",
|
||||
],
|
||||
extras_require={
|
||||
"diarization": ["diart"],
|
||||
"vac": ["torch"],
|
||||
"sentence": ["mosestokenizer", "wtpsplit"],
|
||||
"whisper": ["whisper"],
|
||||
"whisper-timestamped": ["whisper-timestamped"],
|
||||
"mlx-whisper": ["mlx-whisper"],
|
||||
"openai": ["openai"],
|
||||
"simulstreaming": [
|
||||
"torch",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
"numpy<2.0.0",
|
||||
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
|
||||
],
|
||||
},
|
||||
package_data={
|
||||
'whisperlivekit': ['web/*.html'],
|
||||
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
|
||||
'whisperlivekit.simul_whisper.whisper.assets': ['*.tiktoken', '*.npz'],
|
||||
},
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'whisperlivekit-server=whisperlivekit.basic_server:main',
|
||||
],
|
||||
},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
||||
],
|
||||
python_requires=">=3.9",
|
||||
)
|
||||
803
test_backend_offline.py
Normal file
@@ -0,0 +1,803 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Offline test harness and benchmark suite for WhisperLiveKit backends.
|
||||
|
||||
Simulates a client-server session by feeding audio files as PCM bytes through
|
||||
the full AudioProcessor pipeline (the same path used by the WebSocket server),
|
||||
without needing a browser or microphone.
|
||||
|
||||
Computes WER (Word Error Rate) and timestamp accuracy when ground truth
|
||||
transcript files (.transcript.json) are available alongside audio files.
|
||||
|
||||
Usage:
|
||||
# Test with a single audio file:
|
||||
python test_backend_offline.py --backend faster-whisper --audio audio_tests/00_00_07_english_1_speaker.wav
|
||||
|
||||
# Test all files in audio_tests/:
|
||||
python test_backend_offline.py --backend faster-whisper --no-realtime
|
||||
|
||||
# Override streaming policy:
|
||||
python test_backend_offline.py --backend faster-whisper --policy simulstreaming --no-realtime
|
||||
|
||||
# Multi-backend benchmark (auto-detects all installed backends):
|
||||
python test_backend_offline.py --benchmark --no-realtime
|
||||
|
||||
# Export results as JSON:
|
||||
python test_backend_offline.py --benchmark --no-realtime --json results.json
|
||||
|
||||
# Insert silence for testing silence handling:
|
||||
python test_backend_offline.py --backend faster-whisper --insert-silence 3.0 2.0
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.WARNING,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("test_offline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
JFK_WAV_URL = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||
CACHE_DIR = Path(__file__).parent / ".test_cache"
|
||||
AUDIO_TESTS_DIR = Path(__file__).parent / "audio_tests"
|
||||
AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordTimestamp:
|
||||
"""Word with its start/end time."""
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
"""Structured result from a single test run."""
|
||||
audio_file: str
|
||||
audio_duration_s: float
|
||||
backend: str
|
||||
policy: str
|
||||
language: str
|
||||
chunk_ms: int
|
||||
realtime_pacing: bool
|
||||
# Timing
|
||||
processing_time_s: float
|
||||
rtf: float # real-time factor
|
||||
# Transcription output
|
||||
transcription: str
|
||||
n_lines: int
|
||||
n_responses: int
|
||||
# WER metrics (None if no ground truth)
|
||||
wer: Optional[float] = None
|
||||
wer_details: Optional[dict] = None
|
||||
# Timestamp accuracy (None if no ground truth)
|
||||
timestamp_mae: Optional[float] = None
|
||||
timestamp_max_delta: Optional[float] = None
|
||||
timestamp_median_delta: Optional[float] = None
|
||||
# Word-level timestamps
|
||||
word_timestamps: List[WordTimestamp] = field(default_factory=list)
|
||||
# Raw last response
|
||||
last_response: Optional[dict] = None
|
||||
|
||||
|
||||
def download_sample_audio() -> Path:
|
||||
"""Download the jfk.wav sample if not cached."""
|
||||
CACHE_DIR.mkdir(exist_ok=True)
|
||||
path = CACHE_DIR / "jfk.wav"
|
||||
if not path.exists():
|
||||
logger.info(f"Downloading sample audio to {path} ...")
|
||||
urllib.request.urlretrieve(JFK_WAV_URL, path)
|
||||
logger.info("Done.")
|
||||
return path
|
||||
|
||||
|
||||
def load_audio(path: str) -> np.ndarray:
|
||||
"""Load audio file as float32 mono 16kHz numpy array.
|
||||
|
||||
Supports WAV, FLAC (via soundfile) and MP3, OGG, M4A (via librosa).
|
||||
"""
|
||||
ext = Path(path).suffix.lower()
|
||||
if ext in (".mp3", ".ogg", ".m4a"):
|
||||
import librosa
|
||||
audio, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
|
||||
return audio.astype(np.float32)
|
||||
|
||||
import soundfile as sf
|
||||
audio, sr = sf.read(path, dtype="float32")
|
||||
if audio.ndim > 1:
|
||||
audio = audio.mean(axis=1)
|
||||
if sr != SAMPLE_RATE:
|
||||
import librosa
|
||||
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
|
||||
return audio
|
||||
|
||||
|
||||
def insert_silence(audio: np.ndarray, silence_sec: float, position_sec: float) -> np.ndarray:
|
||||
"""Insert silence into audio at a given position.
|
||||
|
||||
Args:
|
||||
audio: Float32 mono audio array at SAMPLE_RATE.
|
||||
silence_sec: Duration of silence to insert in seconds.
|
||||
position_sec: Position in seconds where silence starts.
|
||||
Returns:
|
||||
New audio array with silence inserted.
|
||||
"""
|
||||
pos_samples = int(position_sec * SAMPLE_RATE)
|
||||
silence_samples = int(silence_sec * SAMPLE_RATE)
|
||||
pos_samples = min(pos_samples, len(audio))
|
||||
silence = np.zeros(silence_samples, dtype=np.float32)
|
||||
return np.concatenate([audio[:pos_samples], silence, audio[pos_samples:]])
|
||||
|
||||
|
||||
def float32_to_s16le_bytes(audio: np.ndarray) -> bytes:
|
||||
"""Convert float32 audio to s16le PCM bytes (what the browser sends)."""
|
||||
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
|
||||
|
||||
|
||||
def create_engine(
|
||||
backend: str, model_size: str, lan: str,
|
||||
diarization: bool = False,
|
||||
diarization_backend: str = "",
|
||||
vac: bool = True,
|
||||
policy: str = "",
|
||||
):
|
||||
"""Create a TranscriptionEngine with the given backend config."""
|
||||
import gc
|
||||
from whisperlivekit.core import TranscriptionEngine
|
||||
|
||||
# Reset singleton so we get a fresh instance
|
||||
TranscriptionEngine._instance = None
|
||||
TranscriptionEngine._initialized = False
|
||||
gc.collect()
|
||||
|
||||
kwargs = dict(
|
||||
backend=backend,
|
||||
lan=lan,
|
||||
pcm_input=True,
|
||||
vac=vac,
|
||||
transcription=True,
|
||||
diarization=diarization,
|
||||
)
|
||||
if diarization_backend:
|
||||
kwargs["diarization_backend"] = diarization_backend
|
||||
if model_size:
|
||||
kwargs["model_size"] = model_size
|
||||
if policy:
|
||||
kwargs["backend_policy"] = policy
|
||||
|
||||
return TranscriptionEngine(**kwargs)
|
||||
|
||||
|
||||
def _extract_text_from_response(response_dict: dict) -> str:
|
||||
"""Extract full transcription text from a FrontData dict."""
|
||||
def _strip_or_empty(value: object) -> str:
|
||||
return value.strip() if isinstance(value, str) else ""
|
||||
|
||||
segments = response_dict.get("lines", [])
|
||||
full_text = " ".join(
|
||||
text
|
||||
for seg in segments
|
||||
if isinstance(seg, dict)
|
||||
for text in [_strip_or_empty(seg.get("text"))]
|
||||
if text
|
||||
)
|
||||
buf = _strip_or_empty(response_dict.get("buffer_transcription"))
|
||||
if buf:
|
||||
full_text = f"{full_text} {buf}".strip() if full_text else buf
|
||||
return full_text
|
||||
|
||||
|
||||
async def run_test(
|
||||
engine, audio: np.ndarray, chunk_ms: int, realtime: bool,
|
||||
audio_file: str = "", backend: str = "", policy: str = "", lan: str = "",
|
||||
) -> TestResult:
|
||||
"""
|
||||
Simulate a client session through the full AudioProcessor pipeline.
|
||||
|
||||
1. Create AudioProcessor (one per "client session")
|
||||
2. Start async pipeline (transcription_processor, results_formatter, etc.)
|
||||
3. Feed audio as PCM bytes in timed chunks
|
||||
4. Collect and display FrontData responses
|
||||
5. Signal EOF and cleanup
|
||||
"""
|
||||
from whisperlivekit.audio_processor import AudioProcessor
|
||||
|
||||
chunk_samples = int(SAMPLE_RATE * chunk_ms / 1000)
|
||||
total_samples = len(audio)
|
||||
audio_duration = total_samples / SAMPLE_RATE
|
||||
|
||||
logger.info(
|
||||
f"Audio: {audio_duration:.2f}s | "
|
||||
f"Chunk: {chunk_ms}ms ({chunk_samples} samples) | "
|
||||
f"Steps: {total_samples // chunk_samples + 1} | "
|
||||
f"Realtime: {realtime}"
|
||||
)
|
||||
|
||||
# --- Server side: create processor and start pipeline ---
|
||||
processor = AudioProcessor(transcription_engine=engine)
|
||||
results_generator = await processor.create_tasks()
|
||||
|
||||
# Collect results in background (like handle_websocket_results)
|
||||
all_responses = []
|
||||
response_count = 0
|
||||
last_printed_text = ""
|
||||
|
||||
async def collect_results():
|
||||
nonlocal response_count, last_printed_text
|
||||
async for response in results_generator:
|
||||
all_responses.append(response)
|
||||
response_count += 1
|
||||
d = response.to_dict()
|
||||
|
||||
# Only print when transcription text actually changes
|
||||
current_text = _extract_text_from_response(d)
|
||||
if current_text and current_text != last_printed_text:
|
||||
buf = d.get("buffer_transcription")
|
||||
buf = buf.strip() if isinstance(buf, str) else ""
|
||||
committed = current_text
|
||||
if buf and committed.endswith(buf):
|
||||
committed = committed[:-len(buf)].strip()
|
||||
|
||||
# Show committed text + buffer separately
|
||||
display = committed
|
||||
if buf:
|
||||
display = f"{committed} \033[90m{buf}\033[0m" if committed else f"\033[90m{buf}\033[0m"
|
||||
print(f" > {display}", flush=True)
|
||||
last_printed_text = current_text
|
||||
|
||||
result_task = asyncio.create_task(collect_results())
|
||||
|
||||
# --- Client side: feed audio as PCM bytes ---
|
||||
t_start = time.time()
|
||||
|
||||
for offset in range(0, total_samples, chunk_samples):
|
||||
chunk = audio[offset : offset + chunk_samples]
|
||||
pcm_bytes = float32_to_s16le_bytes(chunk)
|
||||
await processor.process_audio(pcm_bytes)
|
||||
if realtime:
|
||||
await asyncio.sleep(chunk_ms / 1000)
|
||||
|
||||
feed_elapsed = time.time() - t_start
|
||||
|
||||
logger.info(f"Audio fed in {feed_elapsed:.2f}s. Signaling EOF...")
|
||||
|
||||
# Signal end of audio (like client disconnect / empty message)
|
||||
await processor.process_audio(None)
|
||||
|
||||
# Wait for pipeline to drain completely
|
||||
try:
|
||||
await asyncio.wait_for(result_task, timeout=120.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timed out waiting for results. Proceeding with cleanup.")
|
||||
result_task.cancel()
|
||||
try:
|
||||
await result_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# --- Capture word-level timestamps before cleanup ---
|
||||
word_timestamps = []
|
||||
try:
|
||||
state = await processor.get_current_state()
|
||||
for token in state.tokens:
|
||||
if hasattr(token, 'start') and hasattr(token, 'text') and token.text:
|
||||
word_timestamps.append(WordTimestamp(
|
||||
word=token.text.strip(),
|
||||
start=round(token.start, 3),
|
||||
end=round(token.end, 3),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not capture word timestamps: {e}")
|
||||
|
||||
# Cleanup
|
||||
await processor.cleanup()
|
||||
|
||||
total_elapsed = time.time() - t_start
|
||||
|
||||
# --- Build result ---
|
||||
transcription = ""
|
||||
n_lines = 0
|
||||
last_response_dict = None
|
||||
|
||||
if all_responses:
|
||||
last = all_responses[-1].to_dict()
|
||||
last_response_dict = last
|
||||
n_lines = len(last.get("lines", []))
|
||||
transcription = _extract_text_from_response(last)
|
||||
|
||||
# --- Compute WER and timestamp accuracy against ground truth ---
|
||||
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy
|
||||
|
||||
wer_val = None
|
||||
wer_details = None
|
||||
ts_mae = None
|
||||
ts_max_delta = None
|
||||
ts_median_delta = None
|
||||
|
||||
gt_path = Path(audio_file).with_suffix(".transcript.json")
|
||||
if not gt_path.exists():
|
||||
gt_path = AUDIO_TESTS_DIR / gt_path
|
||||
gt = None
|
||||
if gt_path.exists():
|
||||
with open(gt_path) as f:
|
||||
gt = json.load(f)
|
||||
|
||||
# WER
|
||||
gt_text = " ".join(w["word"] for w in gt)
|
||||
wer_result = compute_wer(gt_text, transcription)
|
||||
wer_val = round(wer_result["wer"], 4)
|
||||
wer_details = wer_result
|
||||
|
||||
# Timestamp accuracy
|
||||
if word_timestamps:
|
||||
pred_dicts = [{"word": wt.word, "start": wt.start, "end": wt.end} for wt in word_timestamps]
|
||||
ts_result = compute_timestamp_accuracy(pred_dicts, gt)
|
||||
ts_mae = ts_result["mae_start"]
|
||||
ts_max_delta = ts_result["max_delta_start"]
|
||||
ts_median_delta = ts_result["median_delta_start"]
|
||||
|
||||
result = TestResult(
|
||||
audio_file=audio_file,
|
||||
audio_duration_s=round(audio_duration, 2),
|
||||
backend=backend,
|
||||
policy=policy,
|
||||
language=lan,
|
||||
chunk_ms=chunk_ms,
|
||||
realtime_pacing=realtime,
|
||||
processing_time_s=round(total_elapsed, 2),
|
||||
rtf=round(total_elapsed / audio_duration, 2),
|
||||
transcription=transcription,
|
||||
n_lines=n_lines,
|
||||
n_responses=response_count,
|
||||
wer=wer_val,
|
||||
wer_details=wer_details,
|
||||
timestamp_mae=round(ts_mae, 3) if ts_mae is not None else None,
|
||||
timestamp_max_delta=round(ts_max_delta, 3) if ts_max_delta is not None else None,
|
||||
timestamp_median_delta=round(ts_median_delta, 3) if ts_median_delta is not None else None,
|
||||
word_timestamps=word_timestamps,
|
||||
last_response=last_response_dict,
|
||||
)
|
||||
|
||||
# --- Print summary ---
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"RESULT: {audio_file}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"Transcription: {transcription}")
|
||||
print(f"Lines: {n_lines} | Responses: {response_count}")
|
||||
print(f"Audio: {audio_duration:.2f}s | Time: {total_elapsed:.2f}s | RTF: {result.rtf:.2f}x")
|
||||
|
||||
if wer_val is not None:
|
||||
print(f"WER: {wer_val:.2%} (S={wer_details['substitutions']} I={wer_details['insertions']} D={wer_details['deletions']})")
|
||||
|
||||
# Print word timestamps if available
|
||||
if word_timestamps:
|
||||
print(f"\nWord timestamps ({len(word_timestamps)} words):")
|
||||
for wt in word_timestamps:
|
||||
print(f" [{wt.start:6.2f} - {wt.end:6.2f}] {wt.word}")
|
||||
|
||||
# Detailed comparison with ground truth
|
||||
if gt:
|
||||
print(f"\n vs Ground truth ({len(gt)} words):")
|
||||
max_words = max(len(word_timestamps), len(gt))
|
||||
for i in range(max_words):
|
||||
pred = word_timestamps[i] if i < len(word_timestamps) else None
|
||||
ref = gt[i] if i < len(gt) else None
|
||||
p_str = f"[{pred.start:5.2f}-{pred.end:5.2f}] {pred.word:<15}" if pred else " " * 30
|
||||
r_str = f"[{ref['start']:5.2f}-{ref['end']:5.2f}] {ref['word']:<15}" if ref else ""
|
||||
delta = ""
|
||||
if pred and ref:
|
||||
d = pred.start - ref['start']
|
||||
delta = f" Δstart={d:+.2f}"
|
||||
print(f" {p_str} | {r_str}{delta}")
|
||||
|
||||
if ts_mae is not None:
|
||||
print(f"\n Timestamp stats: MAE={ts_mae:.3f}s max|Δ|={ts_max_delta:.3f}s median|Δ|={ts_median_delta:.3f}s")
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def discover_audio_files(directory: str) -> List[Path]:
|
||||
"""Find all supported audio files in directory."""
|
||||
d = Path(directory)
|
||||
files = sorted(
|
||||
p for p in d.iterdir()
|
||||
if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
|
||||
)
|
||||
return files
|
||||
|
||||
|
||||
async def run_all_tests(
|
||||
engine, audio_files: List[Path], chunk_ms: int, realtime: bool,
|
||||
backend: str, policy: str, lan: str, max_duration: float = 60.0,
|
||||
silence_insertions: Optional[List[List[float]]] = None,
|
||||
) -> List[TestResult]:
|
||||
"""Run tests on multiple audio files sequentially."""
|
||||
results = []
|
||||
for audio_path in audio_files:
|
||||
# Detect language from filename if "french" in name
|
||||
file_lan = lan
|
||||
if "french" in audio_path.name.lower() and lan == "en":
|
||||
file_lan = "fr"
|
||||
logger.info(f"Auto-detected language 'fr' from filename")
|
||||
|
||||
audio = load_audio(str(audio_path))
|
||||
|
||||
# Insert silence segments (applied in reverse position order to keep offsets valid)
|
||||
if silence_insertions:
|
||||
for secs, at_sec in sorted(silence_insertions, key=lambda x: x[1], reverse=True):
|
||||
logger.info(f"Inserting {secs:.1f}s silence at {at_sec:.1f}s")
|
||||
audio = insert_silence(audio, secs, at_sec)
|
||||
|
||||
duration = len(audio) / SAMPLE_RATE
|
||||
|
||||
if duration > max_duration:
|
||||
logger.info(f"Skipping {audio_path.name} ({duration:.0f}s > {max_duration:.0f}s max)")
|
||||
continue
|
||||
|
||||
print(f"\n{'#' * 60}")
|
||||
print(f"# Testing: {audio_path.name} ({duration:.1f}s)")
|
||||
print(f"{'#' * 60}")
|
||||
|
||||
result = await run_test(
|
||||
engine, audio, chunk_ms, realtime,
|
||||
audio_file=audio_path.name, backend=backend, policy=policy, lan=file_lan,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_benchmark_summary(results: List[TestResult]):
|
||||
"""Print a tabular summary of all test results."""
|
||||
print(f"\n{'=' * 110}")
|
||||
print("BENCHMARK SUMMARY")
|
||||
print(f"{'=' * 110}")
|
||||
print(
|
||||
f"{'File':<40} {'Duration':>8} {'Time':>8} {'RTF':>6} "
|
||||
f"{'WER':>7} {'MAE(s)':>7} {'Lines':>5}"
|
||||
)
|
||||
print(f"{'-' * 110}")
|
||||
for r in results:
|
||||
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
|
||||
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
|
||||
print(
|
||||
f"{r.audio_file:<40} {r.audio_duration_s:>7.1f}s {r.processing_time_s:>7.1f}s "
|
||||
f"{r.rtf:>5.2f}x {wer_str:>7} {mae_str:>7} {r.n_lines:>5}"
|
||||
)
|
||||
print(f"{'-' * 110}")
|
||||
total_audio = sum(r.audio_duration_s for r in results)
|
||||
total_time = sum(r.processing_time_s for r in results)
|
||||
avg_rtf = total_time / total_audio if total_audio > 0 else 0
|
||||
wer_vals = [r.wer for r in results if r.wer is not None]
|
||||
avg_wer_str = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
|
||||
mae_vals = [r.timestamp_mae for r in results if r.timestamp_mae is not None]
|
||||
avg_mae_str = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
|
||||
print(
|
||||
f"{'TOTAL/AVG':<40} {total_audio:>7.1f}s {total_time:>7.1f}s "
|
||||
f"{avg_rtf:>5.2f}x {avg_wer_str:>7} {avg_mae_str:>7}"
|
||||
)
|
||||
print(f"{'=' * 110}")
|
||||
|
||||
# Print transcription excerpts
|
||||
print(f"\nTRANSCRIPTIONS:")
|
||||
print(f"{'-' * 110}")
|
||||
for r in results:
|
||||
excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription
|
||||
print(f" {r.audio_file}:")
|
||||
print(f" {excerpt}")
|
||||
print(f"{'=' * 110}")
|
||||
|
||||
|
||||
def detect_available_backends() -> List[dict]:
|
||||
"""Probe which backends can be imported and return (backend, policy) combos.
|
||||
|
||||
Returns list of dicts with keys: backend, policy, description.
|
||||
"""
|
||||
combos = []
|
||||
|
||||
# faster-whisper
|
||||
try:
|
||||
import faster_whisper # noqa: F401
|
||||
combos.append({"backend": "faster-whisper", "policy": "localagreement", "description": "faster-whisper + LocalAgreement"})
|
||||
combos.append({"backend": "faster-whisper", "policy": "simulstreaming", "description": "faster-whisper + SimulStreaming"})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# mlx-whisper (macOS only)
|
||||
try:
|
||||
import mlx_whisper # noqa: F401
|
||||
combos.append({"backend": "mlx-whisper", "policy": "localagreement", "description": "mlx-whisper + LocalAgreement"})
|
||||
combos.append({"backend": "mlx-whisper", "policy": "simulstreaming", "description": "mlx-whisper + SimulStreaming"})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# openai-whisper
|
||||
try:
|
||||
import whisper # noqa: F401
|
||||
combos.append({"backend": "whisper", "policy": "localagreement", "description": "openai-whisper + LocalAgreement"})
|
||||
combos.append({"backend": "whisper", "policy": "simulstreaming", "description": "openai-whisper + SimulStreaming"})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# voxtral-mlx
|
||||
try:
|
||||
from whisperlivekit.voxtral_mlx import VoxtralMLXModel # noqa: F401
|
||||
combos.append({"backend": "voxtral-mlx", "policy": "voxtral", "description": "voxtral-mlx (MLX)"})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# voxtral (HuggingFace)
|
||||
try:
|
||||
from transformers import AutoModelForSpeechSeq2Seq # noqa: F401
|
||||
combos.append({"backend": "voxtral", "policy": "voxtral", "description": "voxtral (HuggingFace)"})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return combos
|
||||
|
||||
|
||||
def print_cross_backend_comparison(all_results: List[TestResult]):
|
||||
"""Print a comparison table across backends and policies."""
|
||||
print(f"\n{'=' * 110}")
|
||||
print("CROSS-BACKEND BENCHMARK COMPARISON")
|
||||
print(f"{'=' * 110}")
|
||||
print(
|
||||
f"{'Backend':<18} {'Policy':<16} {'File':<30} "
|
||||
f"{'WER':>7} {'RTF':>6} {'MAE(s)':>7} {'MaxΔ(s)':>8}"
|
||||
)
|
||||
print(f"{'-' * 110}")
|
||||
|
||||
for r in all_results:
|
||||
wer_str = f"{r.wer:.2%}" if r.wer is not None else " -"
|
||||
rtf_str = f"{r.rtf:.2f}x"
|
||||
mae_str = f"{r.timestamp_mae:.3f}" if r.timestamp_mae is not None else " -"
|
||||
max_str = f"{r.timestamp_max_delta:.3f}" if r.timestamp_max_delta is not None else " -"
|
||||
# Truncate filename for readability
|
||||
fname = r.audio_file[:28] + ".." if len(r.audio_file) > 30 else r.audio_file
|
||||
print(
|
||||
f"{r.backend:<18} {r.policy:<16} {fname:<30} "
|
||||
f"{wer_str:>7} {rtf_str:>6} {mae_str:>7} {max_str:>8}"
|
||||
)
|
||||
|
||||
print(f"{'-' * 110}")
|
||||
|
||||
# Per-backend averages
|
||||
from collections import defaultdict
|
||||
by_combo = defaultdict(list)
|
||||
for r in all_results:
|
||||
by_combo[(r.backend, r.policy)].append(r)
|
||||
|
||||
print(f"\n{'Backend':<18} {'Policy':<16} {'Avg WER':>8} {'Avg RTF':>8} {'Avg MAE':>8} {'Files':>6}")
|
||||
print(f"{'-' * 80}")
|
||||
for (backend, policy), group in sorted(by_combo.items()):
|
||||
wer_vals = [r.wer for r in group if r.wer is not None]
|
||||
rtf_vals = [r.rtf for r in group]
|
||||
mae_vals = [r.timestamp_mae for r in group if r.timestamp_mae is not None]
|
||||
avg_wer = f"{sum(wer_vals)/len(wer_vals):.2%}" if wer_vals else " -"
|
||||
avg_rtf = f"{sum(rtf_vals)/len(rtf_vals):.2f}x"
|
||||
avg_mae = f"{sum(mae_vals)/len(mae_vals):.3f}" if mae_vals else " -"
|
||||
print(
|
||||
f"{backend:<18} {policy:<16} {avg_wer:>8} {avg_rtf:>8} {avg_mae:>8} {len(group):>6}"
|
||||
)
|
||||
print(f"{'=' * 110}")
|
||||
|
||||
|
||||
def _quiet_loggers(verbose: bool):
|
||||
"""Set internal module log levels to reduce noise."""
|
||||
if verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
else:
|
||||
for mod in (
|
||||
"whisperlivekit.audio_processor", "whisperlivekit.simul_whisper",
|
||||
"whisperlivekit.tokens_alignment", "whisperlivekit.simul_whisper.align_att_base",
|
||||
"whisperlivekit.simul_whisper.simul_whisper",
|
||||
):
|
||||
logging.getLogger(mod).setLevel(logging.WARNING)
|
||||
|
||||
|
||||
async def run_benchmark(
|
||||
audio_files: List[Path], chunk_ms: int, realtime: bool,
|
||||
model_size: str, lan: str, max_duration: float, vac: bool,
|
||||
verbose: bool,
|
||||
) -> List[TestResult]:
|
||||
"""Run benchmark across all available backend+policy combinations."""
|
||||
combos = detect_available_backends()
|
||||
if not combos:
|
||||
logger.error("No backends available. Install at least one ASR backend.")
|
||||
return []
|
||||
|
||||
logger.info(f"Detected {len(combos)} backend+policy combinations:")
|
||||
for c in combos:
|
||||
logger.info(f" - {c['description']}")
|
||||
|
||||
all_results = []
|
||||
for i, combo in enumerate(combos, 1):
|
||||
backend = combo["backend"]
|
||||
policy = combo["policy"]
|
||||
desc = combo["description"]
|
||||
|
||||
print(f"\n{'*' * 70}")
|
||||
print(f"* BENCHMARK {i}/{len(combos)}: {desc}")
|
||||
print(f"{'*' * 70}")
|
||||
|
||||
try:
|
||||
engine = create_engine(
|
||||
backend, model_size, lan, vac=vac, policy=policy,
|
||||
)
|
||||
_quiet_loggers(verbose)
|
||||
|
||||
results = await run_all_tests(
|
||||
engine, audio_files, chunk_ms, realtime,
|
||||
backend=backend, policy=policy, lan=lan,
|
||||
max_duration=max_duration,
|
||||
)
|
||||
all_results.extend(results)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run {desc}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Offline backend test harness (AudioProcessor-level)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend", default="faster-whisper",
|
||||
help="Backend: voxtral, voxtral-mlx, auto, faster-whisper, mlx-whisper, whisper.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy", default="",
|
||||
help="Override backend policy: localagreement, simulstreaming, voxtral.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio", default=None,
|
||||
help="Path to a single audio file (WAV, MP3, FLAC, etc.).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-dir", default=None,
|
||||
help="Directory of audio files to test. Defaults to audio_tests/ if neither --audio nor --audio-dir given.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-ms", type=int, default=100,
|
||||
help="Chunk size in milliseconds (simulates real-time interval).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", default="", dest="model_size",
|
||||
help="Model size or HF repo ID.",
|
||||
)
|
||||
parser.add_argument("--lan", default="en", help="Language code.")
|
||||
parser.add_argument(
|
||||
"--no-realtime", action="store_true",
|
||||
help="Skip real-time pacing between chunks (faster but less realistic).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac", action="store_true",
|
||||
help="Disable Voice Activity Classification (send all audio without silence filtering).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization", action="store_true",
|
||||
help="Enable speaker diarization.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
default="",
|
||||
choices=["diart", "sortformer"],
|
||||
help="Diarization backend when --diarization is enabled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benchmark", action="store_true",
|
||||
help="Run benchmark across all detected backend+policy combinations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json", default=None, dest="json_output",
|
||||
help="Write structured JSON results to this file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-duration", type=float, default=60.0,
|
||||
help="Skip audio files longer than this many seconds (default: 60).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--insert-silence", nargs=2, type=float, metavar=("SECS", "AT_SEC"),
|
||||
action="append", default=[],
|
||||
help="Insert SECS of silence at AT_SEC position. Can be repeated. "
|
||||
"E.g.: --insert-silence 3.0 2.0 --insert-silence 5.0 7.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true",
|
||||
help="Show debug-level logs from all components.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
realtime = not args.no_realtime
|
||||
vac = not args.no_vac
|
||||
|
||||
# Resolve audio file(s)
|
||||
if args.audio:
|
||||
audio_files = [Path(args.audio)]
|
||||
elif args.audio_dir:
|
||||
audio_files = discover_audio_files(args.audio_dir)
|
||||
elif AUDIO_TESTS_DIR.is_dir():
|
||||
audio_files = discover_audio_files(str(AUDIO_TESTS_DIR))
|
||||
else:
|
||||
# Fall back to jfk.wav download
|
||||
audio_files = [download_sample_audio()]
|
||||
|
||||
if not audio_files:
|
||||
logger.error("No audio files found.")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(f"Audio files: {[f.name for f in audio_files]}")
|
||||
|
||||
if args.benchmark:
|
||||
# --- Multi-backend benchmark mode ---
|
||||
all_results = asyncio.run(
|
||||
run_benchmark(
|
||||
audio_files, args.chunk_ms, realtime,
|
||||
args.model_size, args.lan, args.max_duration, vac,
|
||||
args.verbose,
|
||||
)
|
||||
)
|
||||
if all_results:
|
||||
print_cross_backend_comparison(all_results)
|
||||
results = all_results
|
||||
else:
|
||||
# --- Single-backend mode ---
|
||||
policy = args.policy
|
||||
logger.info(f"Creating {args.backend} engine...")
|
||||
engine = create_engine(
|
||||
args.backend, args.model_size, args.lan,
|
||||
diarization=args.diarization,
|
||||
diarization_backend=args.diarization_backend,
|
||||
vac=vac,
|
||||
policy=policy,
|
||||
)
|
||||
logger.info("Engine ready.")
|
||||
|
||||
_quiet_loggers(args.verbose)
|
||||
|
||||
results = asyncio.run(
|
||||
run_all_tests(
|
||||
engine, audio_files, args.chunk_ms, realtime,
|
||||
args.backend, policy, args.lan,
|
||||
max_duration=args.max_duration,
|
||||
silence_insertions=args.insert_silence or None,
|
||||
)
|
||||
)
|
||||
|
||||
if len(results) > 1:
|
||||
print_benchmark_summary(results)
|
||||
|
||||
# JSON output
|
||||
if args.json_output and results:
|
||||
json_results = []
|
||||
for r in results:
|
||||
d = asdict(r)
|
||||
d.pop("last_response", None) # too verbose for summary
|
||||
json_results.append(d)
|
||||
Path(args.json_output).write_text(
|
||||
json.dumps(json_results, indent=2, ensure_ascii=False)
|
||||
)
|
||||
logger.info(f"Results written to {args.json_output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
58
tests/conftest.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Shared pytest fixtures for WhisperLiveKit tests."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Transcript
|
||||
|
||||
|
||||
AUDIO_TESTS_DIR = Path(__file__).parent.parent / "audio_tests"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tokens():
|
||||
"""A short sequence of ASRToken objects."""
|
||||
return [
|
||||
ASRToken(start=0.0, end=0.5, text="Hello"),
|
||||
ASRToken(start=0.5, end=1.0, text=" world"),
|
||||
ASRToken(start=1.0, end=1.5, text=" test."),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_silence():
|
||||
"""A completed silence event."""
|
||||
s = Silence(start=1.5, end=3.0, is_starting=False, has_ended=True)
|
||||
s.compute_duration()
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_args():
|
||||
"""Minimal args namespace for AudioProcessor tests."""
|
||||
return SimpleNamespace(
|
||||
diarization=False,
|
||||
transcription=True,
|
||||
target_language="",
|
||||
vac=False,
|
||||
vac_chunk_size=0.04,
|
||||
min_chunk_size=0.1,
|
||||
pcm_input=True,
|
||||
punctuation_split=False,
|
||||
backend="faster-whisper",
|
||||
backend_policy="localagreement",
|
||||
vad=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ground_truth_en():
|
||||
"""Ground truth transcript for the 7s English audio (if available)."""
|
||||
path = AUDIO_TESTS_DIR / "00_00_07_english_1_speaker.transcript.json"
|
||||
if path.exists():
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
return None
|
||||
209
tests/test_audio_processor.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Tests for AudioProcessor pipeline with mocked ASR backends.
|
||||
|
||||
These tests verify the async audio processing pipeline works correctly
|
||||
without requiring any real ASR models to be loaded.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock ASR components
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MockASR:
|
||||
"""Mock ASR model holder."""
|
||||
sep = " "
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(self):
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = "en"
|
||||
self.backend_choice = "mock"
|
||||
|
||||
def transcribe(self, audio):
|
||||
return None
|
||||
|
||||
|
||||
class MockOnlineProcessor:
|
||||
"""Mock online processor that returns canned tokens."""
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(self, asr=None):
|
||||
self.asr = asr or MockASR()
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.end = 0.0
|
||||
self._call_count = 0
|
||||
self._finished = False
|
||||
|
||||
def insert_audio_chunk(self, audio, audio_stream_end_time):
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
self.end = audio_stream_end_time
|
||||
|
||||
def process_iter(self, is_last=False):
|
||||
self._call_count += 1
|
||||
# Emit a token on every call when we have audio
|
||||
if len(self.audio_buffer) > 0:
|
||||
t = self._call_count * 0.5
|
||||
return [ASRToken(start=t, end=t + 0.5, text=f"word{self._call_count}")], self.end
|
||||
return [], self.end
|
||||
|
||||
def get_buffer(self):
|
||||
return Transcript(start=None, end=None, text="")
|
||||
|
||||
def start_silence(self):
|
||||
return [], self.end
|
||||
|
||||
def end_silence(self, silence_duration, offset):
|
||||
pass
|
||||
|
||||
def new_speaker(self, change_speaker):
|
||||
pass
|
||||
|
||||
def finish(self):
|
||||
self._finished = True
|
||||
return [], self.end
|
||||
|
||||
def warmup(self, audio, init_prompt=""):
|
||||
pass
|
||||
|
||||
|
||||
def _make_pcm_bytes(duration_s=0.1, sample_rate=16000):
|
||||
"""Generate silent PCM s16le bytes."""
|
||||
n_samples = int(duration_s * sample_rate)
|
||||
audio = np.zeros(n_samples, dtype=np.float32)
|
||||
return (audio * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
"""Create a mock TranscriptionEngine-like object."""
|
||||
engine = SimpleNamespace(
|
||||
asr=MockASR(),
|
||||
diarization_model=None,
|
||||
translation_model=None,
|
||||
args=SimpleNamespace(
|
||||
diarization=False,
|
||||
transcription=True,
|
||||
target_language="",
|
||||
vac=False,
|
||||
vac_chunk_size=0.04,
|
||||
min_chunk_size=0.1,
|
||||
pcm_input=True,
|
||||
punctuation_split=False,
|
||||
backend="mock",
|
||||
backend_policy="localagreement",
|
||||
vad=True,
|
||||
model_size="base",
|
||||
lan="en",
|
||||
),
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPCMConversion:
|
||||
"""Test PCM byte conversion without needing the full pipeline."""
|
||||
|
||||
def test_s16le_roundtrip(self):
|
||||
"""Convert float32 → s16le → float32 and verify approximate roundtrip."""
|
||||
original = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype=np.float32)
|
||||
s16 = (original * 32768).clip(-32768, 32767).astype(np.int16)
|
||||
pcm_bytes = s16.tobytes()
|
||||
# Direct numpy conversion (same logic as AudioProcessor.convert_pcm_to_float)
|
||||
recovered = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
np.testing.assert_allclose(recovered, original, atol=1 / 32768)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPipelineBasics:
|
||||
async def test_feed_audio_and_get_responses(self, mock_engine):
|
||||
"""Feed audio through the pipeline and verify we get responses."""
|
||||
from whisperlivekit.audio_processor import AudioProcessor
|
||||
|
||||
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
|
||||
processor = AudioProcessor(transcription_engine=mock_engine)
|
||||
results_gen = await processor.create_tasks()
|
||||
|
||||
responses = []
|
||||
|
||||
async def collect():
|
||||
async for resp in results_gen:
|
||||
responses.append(resp)
|
||||
|
||||
task = asyncio.create_task(collect())
|
||||
|
||||
# Feed 2 seconds of audio in 100ms chunks
|
||||
for _ in range(20):
|
||||
await processor.process_audio(_make_pcm_bytes(0.1))
|
||||
|
||||
# Signal EOF
|
||||
await processor.process_audio(None)
|
||||
|
||||
await asyncio.wait_for(task, timeout=10.0)
|
||||
await processor.cleanup()
|
||||
|
||||
# We should have gotten at least one response
|
||||
assert len(responses) > 0
|
||||
|
||||
async def test_eof_terminates_pipeline(self, mock_engine):
|
||||
"""Sending None (EOF) should cleanly terminate the pipeline."""
|
||||
from whisperlivekit.audio_processor import AudioProcessor
|
||||
|
||||
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
|
||||
processor = AudioProcessor(transcription_engine=mock_engine)
|
||||
results_gen = await processor.create_tasks()
|
||||
|
||||
responses = []
|
||||
|
||||
async def collect():
|
||||
async for resp in results_gen:
|
||||
responses.append(resp)
|
||||
|
||||
task = asyncio.create_task(collect())
|
||||
|
||||
# Send a small amount of audio then EOF
|
||||
await processor.process_audio(_make_pcm_bytes(0.5))
|
||||
await processor.process_audio(None)
|
||||
|
||||
await asyncio.wait_for(task, timeout=10.0)
|
||||
await processor.cleanup()
|
||||
|
||||
# Pipeline should have terminated without error
|
||||
assert task.done()
|
||||
|
||||
async def test_empty_audio_no_crash(self, mock_engine):
|
||||
"""Sending EOF immediately (no audio) should not crash."""
|
||||
from whisperlivekit.audio_processor import AudioProcessor
|
||||
|
||||
with patch("whisperlivekit.audio_processor.online_factory", return_value=MockOnlineProcessor()):
|
||||
processor = AudioProcessor(transcription_engine=mock_engine)
|
||||
results_gen = await processor.create_tasks()
|
||||
|
||||
responses = []
|
||||
|
||||
async def collect():
|
||||
async for resp in results_gen:
|
||||
responses.append(resp)
|
||||
|
||||
task = asyncio.create_task(collect())
|
||||
await processor.process_audio(None)
|
||||
|
||||
await asyncio.wait_for(task, timeout=10.0)
|
||||
await processor.cleanup()
|
||||
assert task.done()
|
||||
99
tests/test_config.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tests for WhisperLiveKitConfig."""
|
||||
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from whisperlivekit.config import WhisperLiveKitConfig
|
||||
|
||||
|
||||
class TestDefaults:
|
||||
def test_default_backend(self):
|
||||
c = WhisperLiveKitConfig()
|
||||
assert c.backend == "auto"
|
||||
|
||||
def test_default_policy(self):
|
||||
c = WhisperLiveKitConfig()
|
||||
assert c.backend_policy == "simulstreaming"
|
||||
|
||||
def test_default_language(self):
|
||||
c = WhisperLiveKitConfig()
|
||||
assert c.lan == "auto"
|
||||
|
||||
def test_default_vac(self):
|
||||
c = WhisperLiveKitConfig()
|
||||
assert c.vac is True
|
||||
|
||||
def test_default_model_size(self):
|
||||
c = WhisperLiveKitConfig()
|
||||
assert c.model_size == "base"
|
||||
|
||||
def test_default_transcription(self):
|
||||
c = WhisperLiveKitConfig()
|
||||
assert c.transcription is True
|
||||
assert c.diarization is False
|
||||
|
||||
|
||||
class TestPostInit:
|
||||
def test_en_model_forces_english(self):
|
||||
c = WhisperLiveKitConfig(model_size="tiny.en")
|
||||
assert c.lan == "en"
|
||||
|
||||
def test_en_suffix_with_auto_language(self):
|
||||
c = WhisperLiveKitConfig(model_size="base.en", lan="auto")
|
||||
assert c.lan == "en"
|
||||
|
||||
def test_non_en_model_keeps_language(self):
|
||||
c = WhisperLiveKitConfig(model_size="base", lan="fr")
|
||||
assert c.lan == "fr"
|
||||
|
||||
def test_policy_alias_1(self):
|
||||
c = WhisperLiveKitConfig(backend_policy="1")
|
||||
assert c.backend_policy == "simulstreaming"
|
||||
|
||||
def test_policy_alias_2(self):
|
||||
c = WhisperLiveKitConfig(backend_policy="2")
|
||||
assert c.backend_policy == "localagreement"
|
||||
|
||||
def test_policy_no_alias(self):
|
||||
c = WhisperLiveKitConfig(backend_policy="localagreement")
|
||||
assert c.backend_policy == "localagreement"
|
||||
|
||||
|
||||
class TestFromNamespace:
|
||||
def test_known_keys(self):
|
||||
ns = SimpleNamespace(backend="faster-whisper", lan="en", model_size="large-v3")
|
||||
c = WhisperLiveKitConfig.from_namespace(ns)
|
||||
assert c.backend == "faster-whisper"
|
||||
assert c.lan == "en"
|
||||
assert c.model_size == "large-v3"
|
||||
|
||||
def test_ignores_unknown_keys(self):
|
||||
ns = SimpleNamespace(backend="auto", unknown_key="value", another="x")
|
||||
c = WhisperLiveKitConfig.from_namespace(ns)
|
||||
assert c.backend == "auto"
|
||||
assert not hasattr(c, "unknown_key")
|
||||
|
||||
def test_preserves_defaults_for_missing(self):
|
||||
ns = SimpleNamespace(backend="voxtral-mlx")
|
||||
c = WhisperLiveKitConfig.from_namespace(ns)
|
||||
assert c.lan == "auto"
|
||||
assert c.vac is True
|
||||
|
||||
|
||||
class TestFromKwargs:
|
||||
def test_known_keys(self):
|
||||
c = WhisperLiveKitConfig.from_kwargs(backend="mlx-whisper", lan="fr")
|
||||
assert c.backend == "mlx-whisper"
|
||||
assert c.lan == "fr"
|
||||
|
||||
def test_warns_on_unknown_keys(self, caplog):
|
||||
with caplog.at_level(logging.WARNING, logger="whisperlivekit.config"):
|
||||
c = WhisperLiveKitConfig.from_kwargs(backend="auto", bogus="value")
|
||||
assert c.backend == "auto"
|
||||
assert "bogus" in caplog.text
|
||||
|
||||
def test_post_init_runs(self):
|
||||
c = WhisperLiveKitConfig.from_kwargs(model_size="small.en")
|
||||
assert c.lan == "en"
|
||||
172
tests/test_hypothesis_buffer.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Tests for HypothesisBuffer — the core of LocalAgreement policy."""
|
||||
|
||||
import pytest
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.local_agreement.online_asr import HypothesisBuffer
|
||||
|
||||
|
||||
def make_tokens(words, start=0.0, step=0.5):
|
||||
"""Helper: create ASRToken list from word strings."""
|
||||
tokens = []
|
||||
t = start
|
||||
for w in words:
|
||||
tokens.append(ASRToken(start=t, end=t + step, text=w, probability=0.9))
|
||||
t += step
|
||||
return tokens
|
||||
|
||||
|
||||
class TestInsert:
|
||||
def test_basic_insert(self):
|
||||
buf = HypothesisBuffer()
|
||||
tokens = make_tokens(["hello", "world"])
|
||||
buf.insert(tokens, offset=0.0)
|
||||
assert len(buf.new) == 2
|
||||
assert buf.new[0].text == "hello"
|
||||
|
||||
def test_insert_with_offset(self):
|
||||
buf = HypothesisBuffer()
|
||||
tokens = make_tokens(["hello"], start=0.0)
|
||||
buf.insert(tokens, offset=5.0)
|
||||
assert buf.new[0].start == pytest.approx(5.0)
|
||||
|
||||
def test_insert_filters_old_tokens(self):
|
||||
buf = HypothesisBuffer()
|
||||
buf.last_committed_time = 10.0
|
||||
tokens = make_tokens(["old", "new"], start=5.0, step=3.0)
|
||||
buf.insert(tokens, offset=0.0)
|
||||
# "old" at 5.0 is before last_committed_time - 0.1 = 9.9 → filtered
|
||||
# "new" at 8.0 is also before 9.9 → filtered
|
||||
assert len(buf.new) == 0
|
||||
|
||||
def test_insert_deduplicates_committed(self):
|
||||
buf = HypothesisBuffer()
|
||||
# Commit "hello"
|
||||
tokens1 = make_tokens(["hello", "world"])
|
||||
buf.insert(tokens1, offset=0.0)
|
||||
buf.flush() # commits "hello" (buffer was empty, so nothing matches)
|
||||
# Actually with empty buffer, flush won't commit anything
|
||||
# Let's do it properly: two rounds
|
||||
buf2 = HypothesisBuffer()
|
||||
first = make_tokens(["hello", "world"])
|
||||
buf2.insert(first, offset=0.0)
|
||||
buf2.flush() # buffer was empty → no commits, buffer = ["hello", "world"]
|
||||
|
||||
second = make_tokens(["hello", "world", "test"])
|
||||
buf2.insert(second, offset=0.0)
|
||||
committed = buf2.flush()
|
||||
# LCP of ["hello", "world"] and ["hello", "world", "test"] = ["hello", "world"]
|
||||
assert len(committed) == 2
|
||||
assert committed[0].text == "hello"
|
||||
assert committed[1].text == "world"
|
||||
|
||||
|
||||
class TestFlush:
|
||||
def test_flush_empty(self):
|
||||
buf = HypothesisBuffer()
|
||||
committed = buf.flush()
|
||||
assert committed == []
|
||||
|
||||
def test_flush_lcp_matching(self):
|
||||
buf = HypothesisBuffer()
|
||||
# Round 1: establish buffer
|
||||
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
|
||||
buf.flush() # buffer = ["hello", "world"], committed = []
|
||||
|
||||
# Round 2: same prefix, new suffix
|
||||
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
|
||||
committed = buf.flush()
|
||||
assert [t.text for t in committed] == ["hello", "world"]
|
||||
|
||||
def test_flush_no_match(self):
|
||||
buf = HypothesisBuffer()
|
||||
# Round 1
|
||||
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
|
||||
buf.flush()
|
||||
|
||||
# Round 2: completely different
|
||||
buf.insert(make_tokens(["foo", "bar"]), offset=0.0)
|
||||
committed = buf.flush()
|
||||
assert committed == []
|
||||
|
||||
def test_flush_partial_match(self):
|
||||
buf = HypothesisBuffer()
|
||||
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
|
||||
buf.flush()
|
||||
|
||||
buf.insert(make_tokens(["hello", "earth", "again"]), offset=0.0)
|
||||
committed = buf.flush()
|
||||
assert len(committed) == 1
|
||||
assert committed[0].text == "hello"
|
||||
|
||||
def test_flush_updates_last_committed(self):
|
||||
buf = HypothesisBuffer()
|
||||
buf.insert(make_tokens(["hello", "world"]), offset=0.0)
|
||||
buf.flush()
|
||||
|
||||
buf.insert(make_tokens(["hello", "world", "test"]), offset=0.0)
|
||||
buf.flush()
|
||||
assert buf.last_committed_word == "world"
|
||||
assert buf.last_committed_time > 0
|
||||
|
||||
def test_flush_with_confidence_validation(self):
|
||||
buf = HypothesisBuffer(confidence_validation=True)
|
||||
high_conf = [
|
||||
ASRToken(start=0.0, end=0.5, text="sure", probability=0.99),
|
||||
ASRToken(start=0.5, end=1.0, text="maybe", probability=0.5),
|
||||
]
|
||||
buf.insert(high_conf, offset=0.0)
|
||||
committed = buf.flush()
|
||||
# "sure" has p>0.95 → committed immediately
|
||||
assert len(committed) == 1
|
||||
assert committed[0].text == "sure"
|
||||
|
||||
|
||||
class TestPopCommitted:
|
||||
def test_pop_removes_old(self):
|
||||
buf = HypothesisBuffer()
|
||||
buf.committed_in_buffer = make_tokens(["a", "b", "c"], start=0.0, step=1.0)
|
||||
# "a": end=1.0, "b": end=2.0, "c": end=3.0
|
||||
# pop_committed removes tokens with end <= time
|
||||
buf.pop_committed(2.0)
|
||||
# "a" (end=1.0) and "b" (end=2.0) removed, "c" (end=3.0) remains
|
||||
assert len(buf.committed_in_buffer) == 1
|
||||
assert buf.committed_in_buffer[0].text == "c"
|
||||
|
||||
def test_pop_nothing(self):
|
||||
buf = HypothesisBuffer()
|
||||
buf.committed_in_buffer = make_tokens(["a", "b"], start=5.0)
|
||||
buf.pop_committed(0.0)
|
||||
assert len(buf.committed_in_buffer) == 2
|
||||
|
||||
def test_pop_all(self):
|
||||
buf = HypothesisBuffer()
|
||||
buf.committed_in_buffer = make_tokens(["a", "b"], start=0.0, step=0.5)
|
||||
buf.pop_committed(100.0)
|
||||
assert len(buf.committed_in_buffer) == 0
|
||||
|
||||
|
||||
class TestStreamingSimulation:
|
||||
"""Multi-round insert/flush simulating real streaming behavior."""
|
||||
|
||||
def test_three_rounds(self):
|
||||
buf = HypothesisBuffer()
|
||||
all_committed = []
|
||||
|
||||
# Round 1: "this is"
|
||||
buf.insert(make_tokens(["this", "is"]), offset=0.0)
|
||||
all_committed.extend(buf.flush())
|
||||
|
||||
# Round 2: "this is a test"
|
||||
buf.insert(make_tokens(["this", "is", "a", "test"]), offset=0.0)
|
||||
all_committed.extend(buf.flush())
|
||||
|
||||
# Round 3: "this is a test today"
|
||||
buf.insert(make_tokens(["this", "is", "a", "test", "today"]), offset=0.0)
|
||||
all_committed.extend(buf.flush())
|
||||
|
||||
words = [t.text for t in all_committed]
|
||||
assert "this" in words
|
||||
assert "is" in words
|
||||
assert "a" in words
|
||||
assert "test" in words
|
||||
183
tests/test_metrics.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Tests for whisperlivekit.metrics — WER, timestamp accuracy, normalization."""
|
||||
|
||||
import pytest
|
||||
|
||||
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy, normalize_text
|
||||
|
||||
|
||||
class TestNormalizeText:
|
||||
def test_lowercase(self):
|
||||
assert normalize_text("Hello World") == "hello world"
|
||||
|
||||
def test_strip_punctuation(self):
|
||||
assert normalize_text("Hello, world!") == "hello world"
|
||||
|
||||
def test_collapse_whitespace(self):
|
||||
assert normalize_text(" hello world ") == "hello world"
|
||||
|
||||
def test_keep_hyphens(self):
|
||||
assert normalize_text("real-time") == "real-time"
|
||||
|
||||
def test_keep_apostrophes(self):
|
||||
assert normalize_text("don't") == "don't"
|
||||
|
||||
def test_unicode_normalized(self):
|
||||
# e + combining accent should be same as precomposed
|
||||
assert normalize_text("caf\u0065\u0301") == normalize_text("caf\u00e9")
|
||||
|
||||
def test_empty(self):
|
||||
assert normalize_text("") == ""
|
||||
|
||||
def test_only_punctuation(self):
|
||||
assert normalize_text("...!?") == ""
|
||||
|
||||
|
||||
class TestComputeWER:
|
||||
def test_perfect_match(self):
|
||||
result = compute_wer("hello world", "hello world")
|
||||
assert result["wer"] == 0.0
|
||||
assert result["substitutions"] == 0
|
||||
assert result["insertions"] == 0
|
||||
assert result["deletions"] == 0
|
||||
|
||||
def test_case_insensitive(self):
|
||||
result = compute_wer("Hello World", "hello world")
|
||||
assert result["wer"] == 0.0
|
||||
|
||||
def test_punctuation_ignored(self):
|
||||
result = compute_wer("Hello, world!", "hello world")
|
||||
assert result["wer"] == 0.0
|
||||
|
||||
def test_one_substitution(self):
|
||||
result = compute_wer("hello world", "hello earth")
|
||||
assert result["wer"] == pytest.approx(0.5)
|
||||
assert result["substitutions"] == 1
|
||||
|
||||
def test_one_insertion(self):
|
||||
result = compute_wer("hello world", "hello big world")
|
||||
assert result["wer"] == pytest.approx(0.5)
|
||||
assert result["insertions"] == 1
|
||||
|
||||
def test_one_deletion(self):
|
||||
result = compute_wer("hello big world", "hello world")
|
||||
assert result["wer"] == pytest.approx(1 / 3)
|
||||
assert result["deletions"] == 1
|
||||
|
||||
def test_completely_different(self):
|
||||
result = compute_wer("the cat sat", "a dog ran")
|
||||
assert result["wer"] == pytest.approx(1.0)
|
||||
|
||||
def test_empty_reference(self):
|
||||
result = compute_wer("", "hello")
|
||||
assert result["wer"] == 1.0 # 1 insertion / 0 ref → treated as float(m)
|
||||
assert result["ref_words"] == 0
|
||||
|
||||
def test_empty_hypothesis(self):
|
||||
result = compute_wer("hello world", "")
|
||||
assert result["wer"] == pytest.approx(1.0)
|
||||
assert result["deletions"] == 2
|
||||
|
||||
def test_both_empty(self):
|
||||
result = compute_wer("", "")
|
||||
assert result["wer"] == 0.0
|
||||
|
||||
def test_ref_and_hyp_word_counts(self):
|
||||
result = compute_wer("one two three", "one two three four")
|
||||
assert result["ref_words"] == 3
|
||||
assert result["hyp_words"] == 4
|
||||
|
||||
|
||||
class TestComputeTimestampAccuracy:
|
||||
def test_perfect_match(self):
|
||||
words = [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "world", "start": 0.5, "end": 1.0},
|
||||
]
|
||||
result = compute_timestamp_accuracy(words, words)
|
||||
assert result["mae_start"] == 0.0
|
||||
assert result["max_delta_start"] == 0.0
|
||||
assert result["n_matched"] == 2
|
||||
|
||||
def test_constant_offset(self):
|
||||
ref = [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "world", "start": 0.5, "end": 1.0},
|
||||
]
|
||||
pred = [
|
||||
{"word": "hello", "start": 0.1, "end": 0.6},
|
||||
{"word": "world", "start": 0.6, "end": 1.1},
|
||||
]
|
||||
result = compute_timestamp_accuracy(pred, ref)
|
||||
assert result["mae_start"] == pytest.approx(0.1)
|
||||
assert result["max_delta_start"] == pytest.approx(0.1)
|
||||
assert result["n_matched"] == 2
|
||||
|
||||
def test_mismatched_word_counts(self):
|
||||
ref = [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "beautiful", "start": 0.5, "end": 1.0},
|
||||
{"word": "world", "start": 1.0, "end": 1.5},
|
||||
]
|
||||
pred = [
|
||||
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||
{"word": "world", "start": 1.1, "end": 1.6},
|
||||
]
|
||||
result = compute_timestamp_accuracy(pred, ref)
|
||||
assert result["n_matched"] == 2
|
||||
assert result["n_ref"] == 3
|
||||
assert result["n_pred"] == 2
|
||||
|
||||
def test_empty_predicted(self):
|
||||
ref = [{"word": "hello", "start": 0.0, "end": 0.5}]
|
||||
result = compute_timestamp_accuracy([], ref)
|
||||
assert result["mae_start"] is None
|
||||
assert result["n_matched"] == 0
|
||||
|
||||
def test_empty_reference(self):
|
||||
pred = [{"word": "hello", "start": 0.0, "end": 0.5}]
|
||||
result = compute_timestamp_accuracy(pred, [])
|
||||
assert result["mae_start"] is None
|
||||
assert result["n_matched"] == 0
|
||||
|
||||
def test_case_insensitive_matching(self):
|
||||
ref = [{"word": "Hello", "start": 0.0, "end": 0.5}]
|
||||
pred = [{"word": "hello", "start": 0.1, "end": 0.6}]
|
||||
result = compute_timestamp_accuracy(pred, ref)
|
||||
assert result["n_matched"] == 1
|
||||
assert result["mae_start"] == pytest.approx(0.1)
|
||||
|
||||
def test_median_even_count(self):
|
||||
"""Median with even number of matched words should average the two middle values."""
|
||||
ref = [
|
||||
{"word": "a", "start": 0.0, "end": 0.2},
|
||||
{"word": "b", "start": 0.5, "end": 0.7},
|
||||
{"word": "c", "start": 1.0, "end": 1.2},
|
||||
{"word": "d", "start": 1.5, "end": 1.7},
|
||||
]
|
||||
pred = [
|
||||
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
|
||||
{"word": "b", "start": 0.7, "end": 0.9}, # delta 0.2
|
||||
{"word": "c", "start": 1.3, "end": 1.5}, # delta 0.3
|
||||
{"word": "d", "start": 1.9, "end": 2.1}, # delta 0.4
|
||||
]
|
||||
result = compute_timestamp_accuracy(pred, ref)
|
||||
assert result["n_matched"] == 4
|
||||
# sorted abs deltas: [0.1, 0.2, 0.3, 0.4] -> median = (0.2 + 0.3) / 2 = 0.25
|
||||
assert result["median_delta_start"] == pytest.approx(0.25)
|
||||
|
||||
def test_median_odd_count(self):
|
||||
"""Median with odd number of matched words takes the middle value."""
|
||||
ref = [
|
||||
{"word": "a", "start": 0.0, "end": 0.2},
|
||||
{"word": "b", "start": 0.5, "end": 0.7},
|
||||
{"word": "c", "start": 1.0, "end": 1.2},
|
||||
]
|
||||
pred = [
|
||||
{"word": "a", "start": 0.1, "end": 0.3}, # delta 0.1
|
||||
{"word": "b", "start": 0.8, "end": 1.0}, # delta 0.3
|
||||
{"word": "c", "start": 1.2, "end": 1.4}, # delta 0.2
|
||||
]
|
||||
result = compute_timestamp_accuracy(pred, ref)
|
||||
assert result["n_matched"] == 3
|
||||
# sorted abs deltas: [0.1, 0.2, 0.3] -> median = 0.2
|
||||
assert result["median_delta_start"] == pytest.approx(0.2)
|
||||
99
tests/test_silence_handling.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tests for silence handling — state machine and double-counting regression."""
|
||||
|
||||
import pytest
|
||||
|
||||
from whisperlivekit.timed_objects import Silence
|
||||
|
||||
|
||||
class TestSilenceStateMachine:
|
||||
"""Test Silence object state transitions."""
|
||||
|
||||
def test_initial_state(self):
|
||||
s = Silence(start=1.0, is_starting=True)
|
||||
assert s.is_starting is True
|
||||
assert s.has_ended is False
|
||||
assert s.duration is None
|
||||
assert s.end is None
|
||||
|
||||
def test_end_silence(self):
|
||||
s = Silence(start=1.0, is_starting=True)
|
||||
s.end = 3.0
|
||||
s.is_starting = False
|
||||
s.has_ended = True
|
||||
s.compute_duration()
|
||||
assert s.duration == pytest.approx(2.0)
|
||||
|
||||
def test_very_short_silence(self):
|
||||
s = Silence(start=1.0, end=1.01, is_starting=False, has_ended=True)
|
||||
s.compute_duration()
|
||||
assert s.duration == pytest.approx(0.01)
|
||||
|
||||
def test_zero_duration_silence(self):
|
||||
s = Silence(start=5.0, end=5.0)
|
||||
s.compute_duration()
|
||||
assert s.duration == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestSilenceDoubleCounting:
|
||||
"""Regression tests for the silence double-counting bug.
|
||||
|
||||
The bug: _begin_silence and _end_silence both pushed self.current_silence
|
||||
to the queue. Since they were the same Python object, _end_silence's mutation
|
||||
affected the already-queued start event. The consumer processed both as
|
||||
ended silences, doubling the duration.
|
||||
|
||||
Fix: _begin_silence now pushes a separate Silence object for the start event.
|
||||
"""
|
||||
|
||||
def test_start_and_end_are_separate_objects(self):
|
||||
"""Simulate the fix: start event and end event must be different objects."""
|
||||
# Simulate _begin_silence: creates start event as separate object
|
||||
current_silence = Silence(start=1.0, is_starting=True)
|
||||
start_event = Silence(start=1.0, is_starting=True) # separate copy
|
||||
|
||||
# Simulate _end_silence: mutates current_silence
|
||||
current_silence.end = 3.0
|
||||
current_silence.is_starting = False
|
||||
current_silence.has_ended = True
|
||||
current_silence.compute_duration()
|
||||
|
||||
# start_event should NOT be affected by mutations to current_silence
|
||||
assert start_event.is_starting is True
|
||||
assert start_event.has_ended is False
|
||||
assert start_event.end is None
|
||||
|
||||
# current_silence (end event) has the final state
|
||||
assert current_silence.has_ended is True
|
||||
assert current_silence.duration == pytest.approx(2.0)
|
||||
|
||||
def test_single_object_would_cause_double_counting(self):
|
||||
"""Demonstrate the bug: if same object is used for both events."""
|
||||
shared = Silence(start=1.0, is_starting=True)
|
||||
queue = [shared] # start event queued
|
||||
|
||||
# Mutate (simulates _end_silence)
|
||||
shared.end = 3.0
|
||||
shared.is_starting = False
|
||||
shared.has_ended = True
|
||||
shared.compute_duration()
|
||||
queue.append(shared) # end event queued
|
||||
|
||||
# Both queue items point to the SAME mutated object
|
||||
assert queue[0] is queue[1] # same reference
|
||||
assert queue[0].has_ended is True # start event also shows ended!
|
||||
|
||||
# This would cause double-counting: both items have has_ended=True
|
||||
# and duration=2.0, so the consumer adds 2.0 twice = 4.0
|
||||
|
||||
|
||||
class TestConsecutiveSilences:
|
||||
def test_multiple_silences(self):
|
||||
"""Multiple silence periods should have independent durations."""
|
||||
s1 = Silence(start=1.0, end=2.0)
|
||||
s1.compute_duration()
|
||||
s2 = Silence(start=5.0, end=8.0)
|
||||
s2.compute_duration()
|
||||
assert s1.duration == pytest.approx(1.0)
|
||||
assert s2.duration == pytest.approx(3.0)
|
||||
# Total silence should be sum, not accumulated on single object
|
||||
assert s1.duration + s2.duration == pytest.approx(4.0)
|
||||
185
tests/test_timed_objects.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Tests for whisperlivekit.timed_objects data classes."""
|
||||
|
||||
import pytest
|
||||
|
||||
from whisperlivekit.timed_objects import (
|
||||
ASRToken,
|
||||
FrontData,
|
||||
Segment,
|
||||
Silence,
|
||||
TimedText,
|
||||
Transcript,
|
||||
format_time,
|
||||
)
|
||||
|
||||
|
||||
class TestFormatTime:
|
||||
def test_zero(self):
|
||||
assert format_time(0) == "0:00:00"
|
||||
|
||||
def test_one_minute(self):
|
||||
assert format_time(60) == "0:01:00"
|
||||
|
||||
def test_one_hour(self):
|
||||
assert format_time(3600) == "1:00:00"
|
||||
|
||||
def test_fractional_truncated(self):
|
||||
assert format_time(61.9) == "0:01:01"
|
||||
|
||||
|
||||
class TestASRToken:
|
||||
def test_with_offset(self):
|
||||
t = ASRToken(start=1.0, end=2.0, text="hello")
|
||||
shifted = t.with_offset(0.5)
|
||||
assert shifted.start == pytest.approx(1.5)
|
||||
assert shifted.end == pytest.approx(2.5)
|
||||
assert shifted.text == "hello"
|
||||
|
||||
def test_with_offset_preserves_fields(self):
|
||||
t = ASRToken(start=0.0, end=1.0, text="hi", speaker=2, probability=0.95)
|
||||
shifted = t.with_offset(1.0)
|
||||
assert shifted.speaker == 2
|
||||
assert shifted.probability == 0.95
|
||||
|
||||
def test_is_silence_false(self):
|
||||
t = ASRToken(start=0.0, end=1.0, text="hello")
|
||||
assert t.is_silence() is False
|
||||
|
||||
def test_bool_truthy(self):
|
||||
t = ASRToken(start=0.0, end=1.0, text="hello")
|
||||
assert bool(t) is True
|
||||
|
||||
def test_bool_falsy(self):
|
||||
t = ASRToken(start=0.0, end=1.0, text="")
|
||||
assert bool(t) is False
|
||||
|
||||
|
||||
class TestTimedText:
|
||||
def test_has_punctuation_period(self):
|
||||
t = TimedText(text="hello.")
|
||||
assert t.has_punctuation() is True
|
||||
|
||||
def test_has_punctuation_exclamation(self):
|
||||
t = TimedText(text="wow!")
|
||||
assert t.has_punctuation() is True
|
||||
|
||||
def test_has_punctuation_question(self):
|
||||
t = TimedText(text="really?")
|
||||
assert t.has_punctuation() is True
|
||||
|
||||
def test_has_punctuation_cjk(self):
|
||||
t = TimedText(text="hello。")
|
||||
assert t.has_punctuation() is True
|
||||
|
||||
def test_no_punctuation(self):
|
||||
t = TimedText(text="hello world")
|
||||
assert t.has_punctuation() is False
|
||||
|
||||
def test_duration(self):
|
||||
t = TimedText(start=1.0, end=3.5)
|
||||
assert t.duration() == pytest.approx(2.5)
|
||||
|
||||
def test_contains_timespan(self):
|
||||
outer = TimedText(start=0.0, end=5.0)
|
||||
inner = TimedText(start=1.0, end=3.0)
|
||||
assert outer.contains_timespan(inner) is True
|
||||
assert inner.contains_timespan(outer) is False
|
||||
|
||||
|
||||
class TestSilence:
|
||||
def test_compute_duration(self):
|
||||
s = Silence(start=1.0, end=3.5)
|
||||
d = s.compute_duration()
|
||||
assert d == pytest.approx(2.5)
|
||||
assert s.duration == pytest.approx(2.5)
|
||||
|
||||
def test_compute_duration_none_start(self):
|
||||
s = Silence(start=None, end=3.5)
|
||||
d = s.compute_duration()
|
||||
assert d is None
|
||||
|
||||
def test_compute_duration_none_end(self):
|
||||
s = Silence(start=1.0, end=None)
|
||||
d = s.compute_duration()
|
||||
assert d is None
|
||||
|
||||
def test_is_silence_true(self):
|
||||
s = Silence()
|
||||
assert s.is_silence() is True
|
||||
|
||||
|
||||
class TestTranscript:
|
||||
def test_from_tokens(self, sample_tokens):
|
||||
t = Transcript.from_tokens(sample_tokens, sep="")
|
||||
assert t.text == "Hello world test."
|
||||
assert t.start == pytest.approx(0.0)
|
||||
assert t.end == pytest.approx(1.5)
|
||||
|
||||
def test_from_tokens_with_sep(self, sample_tokens):
|
||||
t = Transcript.from_tokens(sample_tokens, sep="|")
|
||||
assert t.text == "Hello| world| test."
|
||||
|
||||
def test_from_empty_tokens(self):
|
||||
t = Transcript.from_tokens([])
|
||||
assert t.text == ""
|
||||
assert t.start is None
|
||||
assert t.end is None
|
||||
|
||||
def test_from_tokens_with_offset(self, sample_tokens):
|
||||
t = Transcript.from_tokens(sample_tokens, offset=10.0)
|
||||
assert t.start == pytest.approx(10.0)
|
||||
assert t.end == pytest.approx(11.5)
|
||||
|
||||
|
||||
class TestSegment:
|
||||
def test_from_tokens(self, sample_tokens):
|
||||
seg = Segment.from_tokens(sample_tokens)
|
||||
assert seg is not None
|
||||
assert seg.text == "Hello world test."
|
||||
assert seg.start == pytest.approx(0.0)
|
||||
assert seg.end == pytest.approx(1.5)
|
||||
assert seg.speaker == -1
|
||||
|
||||
def test_from_silence_tokens(self):
|
||||
silences = [
|
||||
Silence(start=1.0, end=2.0),
|
||||
Silence(start=2.0, end=3.0),
|
||||
]
|
||||
seg = Segment.from_tokens(silences, is_silence=True)
|
||||
assert seg is not None
|
||||
assert seg.speaker == -2
|
||||
assert seg.is_silence() is True
|
||||
assert seg.text is None
|
||||
|
||||
def test_from_empty_tokens(self):
|
||||
seg = Segment.from_tokens([])
|
||||
assert seg is None
|
||||
|
||||
def test_to_dict(self, sample_tokens):
|
||||
seg = Segment.from_tokens(sample_tokens)
|
||||
d = seg.to_dict()
|
||||
assert "text" in d
|
||||
assert "speaker" in d
|
||||
assert "start" in d
|
||||
assert "end" in d
|
||||
|
||||
|
||||
class TestFrontData:
|
||||
def test_to_dict_empty(self):
|
||||
fd = FrontData()
|
||||
d = fd.to_dict()
|
||||
assert d["lines"] == []
|
||||
assert d["buffer_transcription"] == ""
|
||||
assert "error" not in d
|
||||
|
||||
def test_to_dict_with_error(self):
|
||||
fd = FrontData(error="something broke")
|
||||
d = fd.to_dict()
|
||||
assert d["error"] == "something broke"
|
||||
|
||||
def test_to_dict_with_lines(self, sample_tokens):
|
||||
seg = Segment.from_tokens(sample_tokens)
|
||||
fd = FrontData(lines=[seg])
|
||||
d = fd.to_dict()
|
||||
assert len(d["lines"]) == 1
|
||||
assert d["lines"][0]["text"] == "Hello world test."
|
||||
6575
uv.lock
generated
Normal file
@@ -1,13 +1,13 @@
|
||||
from .download_simulstreaming_backend import download_simulstreaming_backend
|
||||
from .audio_processor import AudioProcessor
|
||||
from .core import TranscriptionEngine
|
||||
from .parse_args import parse_args
|
||||
from .web.web_interface import get_web_interface_html
|
||||
from .web.web_interface import get_inline_ui_html, get_web_interface_html
|
||||
|
||||
__all__ = [
|
||||
"TranscriptionEngine",
|
||||
"AudioProcessor",
|
||||
"parse_args",
|
||||
"get_web_interface_html",
|
||||
"get_inline_ui_html",
|
||||
"download_simulstreaming_backend",
|
||||
]
|
||||
|
||||
47
whisperlivekit/backend_support.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def module_available(module_name):
|
||||
"""Return True if the given module can be imported."""
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def mlx_backend_available(warn_on_missing = False):
|
||||
is_macos = platform.system() == "Darwin"
|
||||
is_arm = platform.machine() == "arm64"
|
||||
available = (
|
||||
is_macos
|
||||
and is_arm
|
||||
and module_available("mlx_whisper")
|
||||
)
|
||||
if not available and warn_on_missing and is_macos and is_arm:
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nMLX Whisper not found but you are on Apple Silicon. "
|
||||
"Consider installing mlx-whisper for better performance: "
|
||||
"`pip install mlx-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
|
||||
|
||||
def voxtral_hf_backend_available():
|
||||
"""Return True if HF Transformers Voxtral backend is available."""
|
||||
return module_available("transformers")
|
||||
|
||||
|
||||
|
||||
def faster_backend_available(warn_on_missing = False):
|
||||
available = module_available("faster_whisper")
|
||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
|
||||
"for better performance: `pip install faster-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
@@ -1,25 +1,26 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
|
||||
get_inline_ui_html, parse_args)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
args = parse_args()
|
||||
config = parse_args()
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(
|
||||
**vars(args),
|
||||
)
|
||||
transcription_engine = TranscriptionEngine(config=config)
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
@@ -33,21 +34,21 @@ app.add_middleware(
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(get_web_interface_html())
|
||||
return HTMLResponse(get_inline_ui_html())
|
||||
|
||||
|
||||
async def handle_websocket_results(websocket, results_generator):
|
||||
"""Consumes results from the audio processor and sends them via WebSocket."""
|
||||
try:
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
await websocket.send_json(response.to_dict())
|
||||
# when the results_generator finishes it means all audio has been processed
|
||||
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in WebSocket results handler: {e}")
|
||||
logger.exception(f"Error in WebSocket results handler: {e}")
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
@@ -58,6 +59,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
)
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "config", "useAudioWorklet": bool(config.pcm_input)})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send config to client: {e}")
|
||||
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
@@ -95,24 +101,26 @@ def main():
|
||||
|
||||
uvicorn_kwargs = {
|
||||
"app": "whisperlivekit.basic_server:app",
|
||||
"host":args.host,
|
||||
"port":args.port,
|
||||
"host": config.host,
|
||||
"port": config.port,
|
||||
"reload": False,
|
||||
"log_level": "info",
|
||||
"lifespan": "on",
|
||||
}
|
||||
|
||||
|
||||
ssl_kwargs = {}
|
||||
if args.ssl_certfile or args.ssl_keyfile:
|
||||
if not (args.ssl_certfile and args.ssl_keyfile):
|
||||
if config.ssl_certfile or config.ssl_keyfile:
|
||||
if not (config.ssl_certfile and config.ssl_keyfile):
|
||||
raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
|
||||
ssl_kwargs = {
|
||||
"ssl_certfile": args.ssl_certfile,
|
||||
"ssl_keyfile": args.ssl_keyfile
|
||||
"ssl_certfile": config.ssl_certfile,
|
||||
"ssl_keyfile": config.ssl_keyfile,
|
||||
}
|
||||
|
||||
if ssl_kwargs:
|
||||
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
||||
if config.forwarded_allow_ips:
|
||||
uvicorn_kwargs = {**uvicorn_kwargs, "forwarded_allow_ips": config.forwarded_allow_ips}
|
||||
|
||||
uvicorn.run(**uvicorn_kwargs)
|
||||
|
||||
|
||||
102
whisperlivekit/config.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Typed configuration for the WhisperLiveKit pipeline."""
|
||||
import logging
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhisperLiveKitConfig:
|
||||
"""Single source of truth for all WhisperLiveKit configuration.
|
||||
|
||||
Replaces the previous dict-based parameter system in TranscriptionEngine.
|
||||
All fields have defaults matching the prior behaviour.
|
||||
"""
|
||||
|
||||
# Server / global
|
||||
host: str = "localhost"
|
||||
port: int = 8000
|
||||
diarization: bool = False
|
||||
punctuation_split: bool = False
|
||||
target_language: str = ""
|
||||
vac: bool = True
|
||||
vac_chunk_size: float = 0.04
|
||||
log_level: str = "DEBUG"
|
||||
ssl_certfile: Optional[str] = None
|
||||
ssl_keyfile: Optional[str] = None
|
||||
forwarded_allow_ips: Optional[str] = None
|
||||
transcription: bool = True
|
||||
vad: bool = True
|
||||
pcm_input: bool = False
|
||||
disable_punctuation_split: bool = False
|
||||
diarization_backend: str = "sortformer"
|
||||
backend_policy: str = "simulstreaming"
|
||||
backend: str = "auto"
|
||||
|
||||
# Transcription common
|
||||
warmup_file: Optional[str] = None
|
||||
min_chunk_size: float = 0.1
|
||||
model_size: str = "base"
|
||||
model_cache_dir: Optional[str] = None
|
||||
model_dir: Optional[str] = None
|
||||
model_path: Optional[str] = None
|
||||
lora_path: Optional[str] = None
|
||||
lan: str = "auto"
|
||||
direct_english_translation: bool = False
|
||||
|
||||
# LocalAgreement-specific
|
||||
buffer_trimming: str = "segment"
|
||||
confidence_validation: bool = False
|
||||
buffer_trimming_sec: float = 15.0
|
||||
|
||||
# SimulStreaming-specific
|
||||
disable_fast_encoder: bool = False
|
||||
custom_alignment_heads: Optional[str] = None
|
||||
frame_threshold: int = 25
|
||||
beams: int = 1
|
||||
decoder_type: Optional[str] = None
|
||||
audio_max_len: float = 20.0
|
||||
audio_min_len: float = 0.0
|
||||
cif_ckpt_path: Optional[str] = None
|
||||
never_fire: bool = False
|
||||
init_prompt: Optional[str] = None
|
||||
static_init_prompt: Optional[str] = None
|
||||
max_context_tokens: Optional[int] = None
|
||||
|
||||
# Diarization (diart)
|
||||
segmentation_model: str = "pyannote/segmentation-3.0"
|
||||
embedding_model: str = "pyannote/embedding"
|
||||
|
||||
# Translation
|
||||
nllb_backend: str = "transformers"
|
||||
nllb_size: str = "600M"
|
||||
|
||||
def __post_init__(self):
|
||||
# .en model suffix forces English
|
||||
if self.model_size and self.model_size.endswith(".en"):
|
||||
self.lan = "en"
|
||||
# Normalize backend_policy aliases
|
||||
if self.backend_policy == "1":
|
||||
self.backend_policy = "simulstreaming"
|
||||
elif self.backend_policy == "2":
|
||||
self.backend_policy = "localagreement"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def from_namespace(cls, ns) -> "WhisperLiveKitConfig":
|
||||
"""Create config from an argparse Namespace, ignoring unknown keys."""
|
||||
known = {f.name for f in fields(cls)}
|
||||
return cls(**{k: v for k, v in vars(ns).items() if k in known})
|
||||
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs) -> "WhisperLiveKitConfig":
|
||||
"""Create config from keyword arguments; warns on unknown keys."""
|
||||
known = {f.name for f in fields(cls)}
|
||||
unknown = set(kwargs.keys()) - known
|
||||
if unknown:
|
||||
logger.warning("Unknown config keys ignored: %s", unknown)
|
||||
return cls(**{k: v for k, v in kwargs.items() if k in known})
|
||||
@@ -1,92 +1,207 @@
|
||||
try:
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
except ImportError:
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict
|
||||
|
||||
from whisperlivekit.config import WhisperLiveKitConfig
|
||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
_lock = threading.Lock() # Thread-safe singleton lock
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Double-checked locking pattern for thread-safe singleton
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
with cls._lock:
|
||||
# Check again inside lock to prevent race condition
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if TranscriptionEngine._initialized:
|
||||
return
|
||||
def __init__(self, config=None, **kwargs):
|
||||
# Thread-safe initialization check
|
||||
with TranscriptionEngine._lock:
|
||||
if TranscriptionEngine._initialized:
|
||||
return
|
||||
|
||||
defaults = {
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"warmup_file": None,
|
||||
"confidence_validation": False,
|
||||
"diarization": False,
|
||||
"punctuation_split": False,
|
||||
"min_chunk_size": 0.5,
|
||||
"model": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
"backend": "faster-whisper",
|
||||
"vac": False,
|
||||
"vac_chunk_size": 0.04,
|
||||
"buffer_trimming": "segment",
|
||||
"buffer_trimming_sec": 15,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
"ssl_keyfile": None,
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
# simulstreaming params:
|
||||
"frame_threshold": 25,
|
||||
"beams": 1,
|
||||
"decoder_type": None,
|
||||
"audio_max_len": 30.0,
|
||||
"audio_min_len": 0.0,
|
||||
"cif_ckpt_path": None,
|
||||
"never_fire": False,
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
}
|
||||
try:
|
||||
self._do_init(config, **kwargs)
|
||||
except Exception:
|
||||
# Reset singleton so a retry is possible
|
||||
with TranscriptionEngine._lock:
|
||||
TranscriptionEngine._instance = None
|
||||
TranscriptionEngine._initialized = False
|
||||
raise
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
with TranscriptionEngine._lock:
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
def _do_init(self, config=None, **kwargs):
|
||||
# Handle negated kwargs from programmatic API
|
||||
if 'no_transcription' in kwargs:
|
||||
config_dict['transcription'] = not kwargs['no_transcription']
|
||||
kwargs['transcription'] = not kwargs.pop('no_transcription')
|
||||
if 'no_vad' in kwargs:
|
||||
config_dict['vad'] = not kwargs['no_vad']
|
||||
|
||||
config_dict.pop('no_transcription', None)
|
||||
config_dict.pop('no_vad', None)
|
||||
kwargs['vad'] = not kwargs.pop('no_vad')
|
||||
if 'no_vac' in kwargs:
|
||||
kwargs['vac'] = not kwargs.pop('no_vac')
|
||||
|
||||
if 'language' in kwargs:
|
||||
config_dict['lan'] = kwargs['language']
|
||||
config_dict.pop('language', None)
|
||||
if config is None:
|
||||
if isinstance(kwargs.get('config'), WhisperLiveKitConfig):
|
||||
config = kwargs.pop('config')
|
||||
else:
|
||||
config = WhisperLiveKitConfig.from_kwargs(**kwargs)
|
||||
self.config = config
|
||||
|
||||
# Backward compat: expose as self.args (Namespace-like) for AudioProcessor etc.
|
||||
self.args = Namespace(**asdict(config))
|
||||
|
||||
self.args = Namespace(**config_dict)
|
||||
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
self.diarization = None
|
||||
|
||||
if self.args.transcription:
|
||||
self.asr, self.tokenizer = backend_factory(self.args)
|
||||
warmup_asr(self.asr, self.args.warmup_file)
|
||||
self.vac_session = None
|
||||
|
||||
if self.args.diarization:
|
||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
||||
self.diarization = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
)
|
||||
|
||||
TranscriptionEngine._initialized = True
|
||||
if config.vac:
|
||||
from whisperlivekit.silero_vad_iterator import is_onnx_available
|
||||
|
||||
if is_onnx_available():
|
||||
from whisperlivekit.silero_vad_iterator import load_onnx_session
|
||||
self.vac_session = load_onnx_session()
|
||||
else:
|
||||
logger.warning(
|
||||
"onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
|
||||
"For multi-user scenarios, install onnxruntime: pip install onnxruntime"
|
||||
)
|
||||
|
||||
transcription_common_params = {
|
||||
"warmup_file": config.warmup_file,
|
||||
"min_chunk_size": config.min_chunk_size,
|
||||
"model_size": config.model_size,
|
||||
"model_cache_dir": config.model_cache_dir,
|
||||
"model_dir": config.model_dir,
|
||||
"model_path": config.model_path,
|
||||
"lora_path": config.lora_path,
|
||||
"lan": config.lan,
|
||||
"direct_english_translation": config.direct_english_translation,
|
||||
}
|
||||
|
||||
if config.transcription:
|
||||
if config.backend == "voxtral-mlx":
|
||||
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR
|
||||
self.tokenizer = None
|
||||
self.asr = VoxtralMLXASR(**transcription_common_params)
|
||||
logger.info("Using Voxtral MLX native backend")
|
||||
elif config.backend == "voxtral":
|
||||
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingASR
|
||||
self.tokenizer = None
|
||||
self.asr = VoxtralHFStreamingASR(**transcription_common_params)
|
||||
logger.info("Using Voxtral HF Transformers streaming backend")
|
||||
elif config.backend_policy == "simulstreaming":
|
||||
simulstreaming_params = {
|
||||
"disable_fast_encoder": config.disable_fast_encoder,
|
||||
"custom_alignment_heads": config.custom_alignment_heads,
|
||||
"frame_threshold": config.frame_threshold,
|
||||
"beams": config.beams,
|
||||
"decoder_type": config.decoder_type,
|
||||
"audio_max_len": config.audio_max_len,
|
||||
"audio_min_len": config.audio_min_len,
|
||||
"cif_ckpt_path": config.cif_ckpt_path,
|
||||
"never_fire": config.never_fire,
|
||||
"init_prompt": config.init_prompt,
|
||||
"static_init_prompt": config.static_init_prompt,
|
||||
"max_context_tokens": config.max_context_tokens,
|
||||
}
|
||||
|
||||
self.tokenizer = None
|
||||
self.asr = SimulStreamingASR(
|
||||
**transcription_common_params,
|
||||
**simulstreaming_params,
|
||||
backend=config.backend,
|
||||
)
|
||||
logger.info(
|
||||
"Using SimulStreaming policy with %s backend",
|
||||
getattr(self.asr, "encoder_backend", "whisper"),
|
||||
)
|
||||
else:
|
||||
whisperstreaming_params = {
|
||||
"buffer_trimming": config.buffer_trimming,
|
||||
"confidence_validation": config.confidence_validation,
|
||||
"buffer_trimming_sec": config.buffer_trimming_sec,
|
||||
}
|
||||
|
||||
self.asr = backend_factory(
|
||||
backend=config.backend,
|
||||
**transcription_common_params,
|
||||
**whisperstreaming_params,
|
||||
)
|
||||
logger.info(
|
||||
"Using LocalAgreement policy with %s backend",
|
||||
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
||||
)
|
||||
|
||||
if config.diarization:
|
||||
if config.diarization_backend == "diart":
|
||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||
self.diarization_model = DiartDiarization(
|
||||
block_duration=config.min_chunk_size,
|
||||
segmentation_model=config.segmentation_model,
|
||||
embedding_model=config.embedding_model,
|
||||
)
|
||||
elif config.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
self.diarization_model = SortformerDiarization()
|
||||
|
||||
self.translation_model = None
|
||||
if config.target_language:
|
||||
if config.lan == 'auto' and config.backend_policy != "simulstreaming":
|
||||
raise ValueError('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||
else:
|
||||
try:
|
||||
from nllw import load_model
|
||||
except ImportError:
|
||||
raise ImportError('To use translation, you must install nllw: `pip install nllw`')
|
||||
self.translation_model = load_model(
|
||||
[config.lan],
|
||||
nllb_backend=config.nllb_backend,
|
||||
nllb_size=config.nllb_size,
|
||||
)
|
||||
|
||||
|
||||
def online_factory(args, asr):
|
||||
if getattr(args, 'backend', None) == "voxtral-mlx":
|
||||
from whisperlivekit.voxtral_mlx_asr import VoxtralMLXOnlineProcessor
|
||||
return VoxtralMLXOnlineProcessor(asr)
|
||||
if getattr(args, 'backend', None) == "voxtral":
|
||||
from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
|
||||
return VoxtralHFStreamingOnlineProcessor(asr)
|
||||
if args.backend_policy == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
return SimulStreamingOnlineProcessor(asr)
|
||||
return OnlineASRProcessor(asr)
|
||||
|
||||
|
||||
def online_diarization_factory(args, diarization_backend):
|
||||
if args.diarization_backend == "diart":
|
||||
online = diarization_backend
|
||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
||||
elif args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import \
|
||||
SortformerDiarizationOnline
|
||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||
else:
|
||||
raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")
|
||||
return online
|
||||
|
||||
|
||||
def online_translation_factory(args, translation_model):
|
||||
#should be at speaker level in the future:
|
||||
#one shared nllb model for all speaker
|
||||
#one tokenizer per speaker/language
|
||||
from nllw import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
|
||||
@@ -1,34 +1,31 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
import numpy as np
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from queue import SimpleQueue, Empty
|
||||
from queue import Empty, SimpleQueue
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import diart.models as m
|
||||
import numpy as np
|
||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||
from diart.inference import StreamingInference
|
||||
from diart.sources import AudioSource
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
from diart.sources import MicrophoneAudioSource
|
||||
from rx.core import Observer
|
||||
from typing import Tuple, Any, List
|
||||
from diart.sources import AudioSource, MicrophoneAudioSource
|
||||
from pyannote.core import Annotation
|
||||
import diart.models as m
|
||||
from rx.core import Observer
|
||||
|
||||
from whisperlivekit.diarization.utils import extract_number
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
m = re.search(r'\d+', s)
|
||||
return int(m.group()) if m else None
|
||||
|
||||
class DiarizationObserver(Observer):
|
||||
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
||||
|
||||
def __init__(self):
|
||||
self.speaker_segments = []
|
||||
self.diarization_segments = []
|
||||
self.processed_time = 0
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
|
||||
def on_next(self, value: Tuple[Annotation, Any]):
|
||||
annotation, audio = value
|
||||
@@ -47,10 +44,10 @@ class DiarizationObserver(Observer):
|
||||
for speaker, label in annotation._labels.items():
|
||||
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
|
||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
self.diarization_segments.append(SpeakerSegment(
|
||||
speaker=speaker,
|
||||
start=start,
|
||||
end=end
|
||||
start=start + self.global_time_offset,
|
||||
end=end + self.global_time_offset
|
||||
))
|
||||
else:
|
||||
logger.debug("\nNo speakers detected in this segment")
|
||||
@@ -58,14 +55,14 @@ class DiarizationObserver(Observer):
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.speaker_segments.copy()
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.speaker_segments = [
|
||||
segment for segment in self.speaker_segments
|
||||
self.diarization_segments = [
|
||||
segment for segment in self.diarization_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
|
||||
@@ -165,7 +162,7 @@ class WebSocketAudioSource(AudioSource):
|
||||
|
||||
|
||||
class DiartDiarization:
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||
|
||||
@@ -177,7 +174,6 @@ class DiartDiarization:
|
||||
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
self.lag_diart = None
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
@@ -199,117 +195,90 @@ class DiartDiarization:
|
||||
self.inference.attach_observers(self.observer)
|
||||
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process audio data for diarization.
|
||||
Only used when working with WebSocketAudioSource.
|
||||
"""
|
||||
def insert_silence(self, silence_duration):
|
||||
self.observer.global_time_offset += silence_duration
|
||||
|
||||
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||
"""Buffer audio for the next diarization step."""
|
||||
if self.custom_source:
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
self.observer.clear_old_segments()
|
||||
return self.observer.get_segments()
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
|
||||
async def diarize(self):
|
||||
"""Return the current speaker segments from the diarization pipeline."""
|
||||
return self.observer.get_segments()
|
||||
|
||||
def close(self):
|
||||
"""Close the audio source."""
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
Uses the segments collected by the observer.
|
||||
|
||||
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
|
||||
"""
|
||||
segments = self.observer.get_segments()
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
|
||||
logger.debug(f"Available segments: {len(segments)}")
|
||||
for i, seg in enumerate(segments[:5]): # Show first 5 segments
|
||||
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
|
||||
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
|
||||
if use_punctuation_split and len(tokens) > 1:
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
|
||||
print("Here are the tokens:",
|
||||
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
||||
|
||||
segment_map = []
|
||||
for segment in segments:
|
||||
speaker_num = extract_number(segment.speaker) + 1
|
||||
segment_map.append((segment.start, segment.end, speaker_num))
|
||||
segment_map.sort(key=lambda x: x[0])
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
current_token = tokens[i]
|
||||
|
||||
is_sentence_end = False
|
||||
if current_token.text and current_token.text.strip():
|
||||
text = current_token.text.strip()
|
||||
if text[-1] in punctuation_marks:
|
||||
is_sentence_end = True
|
||||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
||||
|
||||
if is_sentence_end and current_token.speaker != -1:
|
||||
punctuation_time = current_token.end
|
||||
current_speaker = current_token.speaker
|
||||
|
||||
j = i + 1
|
||||
next_sentence_tokens = []
|
||||
while j < len(tokens):
|
||||
next_token = tokens[j]
|
||||
next_sentence_tokens.append(j)
|
||||
|
||||
# Check if this token ends the next sentence
|
||||
if next_token.text and next_token.text.strip():
|
||||
if next_token.text.strip()[-1] in punctuation_marks:
|
||||
break
|
||||
j += 1
|
||||
|
||||
if next_sentence_tokens:
|
||||
speaker_times = {}
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
token = tokens[idx]
|
||||
# Find which segments overlap with this token
|
||||
for seg_start, seg_end, seg_speaker in segment_map:
|
||||
if not (seg_end <= token.start or seg_start >= token.end):
|
||||
# Calculate overlap duration
|
||||
overlap_start = max(seg_start, token.start)
|
||||
overlap_end = min(seg_end, token.end)
|
||||
overlap_duration = overlap_end - overlap_start
|
||||
|
||||
if seg_speaker not in speaker_times:
|
||||
speaker_times[seg_speaker] = 0
|
||||
speaker_times[seg_speaker] += overlap_duration
|
||||
|
||||
if speaker_times:
|
||||
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
||||
|
||||
if dominant_speaker != current_speaker:
|
||||
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker != dominant_speaker:
|
||||
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
||||
tokens[idx].speaker = dominant_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
else:
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker == -1:
|
||||
tokens[idx].speaker = current_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
|
||||
i += 1
|
||||
|
||||
return end_attributed_speaker
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
for segment in segments:
|
||||
speaker = extract_number(segment.speaker) + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
# print("Segments concatenated:")
|
||||
# for entry in segments_concatenated:
|
||||
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
||||
return segments_concatenated
|
||||
|
||||
|
||||
def add_speaker_to_tokens(segments, tokens):
|
||||
"""
|
||||
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
segments_concatenated = concatenate_speakers(segments)
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
# print(
|
||||
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
|
||||
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
|
||||
# )
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
return tokens
|
||||
|
||||
|
||||
def visualize_tokens(tokens):
|
||||
conversation = [{"speaker": -1, "text": ""}]
|
||||
for token in tokens:
|
||||
speaker = conversation[-1]['speaker']
|
||||
if token.speaker != speaker:
|
||||
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||
else:
|
||||
conversation[-1]['text'] += token.text
|
||||
print("Conversation:")
|
||||
for entry in conversation:
|
||||
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||
327
whisperlivekit/diarization/sortformer_backend.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from queue import Empty, SimpleQueue
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
except ImportError:
|
||||
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.spkcache = None # Speaker cache to store embeddings from start
|
||||
self.spkcache_lengths = None
|
||||
self.spkcache_preds = None # speaker cache predictions
|
||||
self.fifo = None # to save the embedding from the latest chunks
|
||||
self.fifo_lengths = None
|
||||
self.fifo_preds = None
|
||||
self.spk_perm = None
|
||||
self.mean_sil_emb = None
|
||||
self.n_sil_frames = None
|
||||
|
||||
|
||||
class SortformerDiarization:
|
||||
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
|
||||
"""
|
||||
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
|
||||
"""
|
||||
self._load_model(model_name)
|
||||
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and configure the Sortformer model for streaming."""
|
||||
try:
|
||||
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||
self.diar_model.eval()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.diar_model.to(device)
|
||||
|
||||
## to test
|
||||
# for name, param in self.diar_model.named_parameters():
|
||||
# if param.device != device:
|
||||
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
|
||||
|
||||
logger.info(f"Using {device.type.upper()} for Sortformer model")
|
||||
|
||||
self.diar_model.sortformer_modules.chunk_len = 10
|
||||
self.diar_model.sortformer_modules.subsampling_factor = 10
|
||||
self.diar_model.sortformer_modules.chunk_right_context = 0
|
||||
self.diar_model.sortformer_modules.chunk_left_context = 10
|
||||
self.diar_model.sortformer_modules.spkcache_len = 188
|
||||
self.diar_model.sortformer_modules.fifo_len = 188
|
||||
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
self.diar_model.sortformer_modules.log = False
|
||||
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Sortformer model: {e}")
|
||||
raise
|
||||
|
||||
class SortformerDiarizationOnline:
|
||||
def __init__(self, shared_model, sample_rate: int = 16000):
|
||||
"""
|
||||
Initialize the streaming Sortformer diarization system.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate (default: 16000)
|
||||
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.diarization_segments = []
|
||||
self.diar_segments = []
|
||||
self.buffer_audio = np.array([], dtype=np.float32)
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
self.debug = False
|
||||
|
||||
self.diar_model = shared_model.diar_model
|
||||
|
||||
self.audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size=0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0
|
||||
)
|
||||
self.audio2mel.to(self.diar_model.device)
|
||||
|
||||
self.chunk_duration_seconds = (
|
||||
self.diar_model.sortformer_modules.chunk_len *
|
||||
self.diar_model.sortformer_modules.subsampling_factor *
|
||||
self.diar_model.preprocessor._cfg.window_stride
|
||||
)
|
||||
|
||||
self._init_streaming_state()
|
||||
|
||||
self._previous_chunk_features = None
|
||||
self._chunk_index = 0
|
||||
self._len_prediction = None
|
||||
|
||||
# Audio buffer to store PCM chunks for debugging
|
||||
self.audio_buffer = []
|
||||
|
||||
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
|
||||
self.audio_chunk_buffer = []
|
||||
self.accumulated_duration = 0.0
|
||||
|
||||
logger.info("SortformerDiarization initialized successfully")
|
||||
|
||||
|
||||
def _init_streaming_state(self):
|
||||
"""Initialize the streaming state for the model."""
|
||||
batch_size = 1
|
||||
device = self.diar_model.device
|
||||
|
||||
self.streaming_state = StreamingSortformerState()
|
||||
self.streaming_state.spkcache = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_preds = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.fifo = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||
|
||||
def insert_silence(self, silence_duration: Optional[float]):
|
||||
"""
|
||||
Insert silence period by adjusting the global time offset.
|
||||
|
||||
Args:
|
||||
silence_duration: Duration of silence in seconds
|
||||
"""
|
||||
with self.segment_lock:
|
||||
self.global_time_offset += silence_duration
|
||||
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
|
||||
|
||||
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
|
||||
|
||||
async def diarize(self):
|
||||
"""
|
||||
Process audio data for diarization in streaming fashion.
|
||||
|
||||
Args:
|
||||
pcm_array: Audio data as numpy array
|
||||
"""
|
||||
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||
|
||||
if not len(self.buffer_audio) >= threshold:
|
||||
return []
|
||||
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
device = self.diar_model.device
|
||||
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||
|
||||
if self._previous_chunk_features is not None:
|
||||
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||
else:
|
||||
total_features = processed_signal_chunk.to(device)
|
||||
|
||||
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
right_offset = 8
|
||||
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
new_segments = self._process_predictions()
|
||||
|
||||
self._chunk_index += 1
|
||||
return new_segments
|
||||
|
||||
def _process_predictions(self):
|
||||
"""Process model predictions and convert to speaker segments."""
|
||||
preds_np = self.total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
|
||||
if self._len_prediction is None:
|
||||
self._len_prediction = len(active_speakers) #12
|
||||
|
||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||
|
||||
new_segments = []
|
||||
|
||||
with self.segment_lock:
|
||||
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||
current_spk = current_chunk_preds[0]
|
||||
start_time = round(base_time, 2)
|
||||
for idx, spk in enumerate(current_chunk_preds):
|
||||
current_time = round(base_time + idx * frame_duration, 2)
|
||||
if spk != current_spk:
|
||||
new_segments.append(SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=current_time
|
||||
))
|
||||
start_time = current_time
|
||||
current_spk = spk
|
||||
new_segments.append(
|
||||
SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=current_time
|
||||
)
|
||||
)
|
||||
return new_segments
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def close(self):
|
||||
"""Close the diarization system and clean up resources."""
|
||||
logger.info("Closing SortformerDiarization")
|
||||
with self.segment_lock:
|
||||
self.diarization_segments.clear()
|
||||
|
||||
if self.debug:
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # mono audio
|
||||
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(audio_data_int16.tobytes())
|
||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||
|
||||
|
||||
from whisperlivekit.diarization.utils import extract_number
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
import librosa
|
||||
|
||||
async def main():
|
||||
"""TEST ONLY."""
|
||||
an4_audio = 'diarization_audio.wav'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
diarization_backend = SortformerDiarization()
|
||||
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
|
||||
chunk_size = 1600
|
||||
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
new_segments = await diarization.diarize(chunk)
|
||||
print(f"Processed chunk {i // chunk_size + 1}")
|
||||
print(new_segments)
|
||||
|
||||
segments = diarization.get_segments()
|
||||
print("\nDiarization results:")
|
||||
for segment in segments:
|
||||
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||
|
||||
asyncio.run(main())
|
||||
7
whisperlivekit/diarization/utils.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import re
|
||||
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
"""Extract the first integer from a string, e.g. 'speaker_2' -> 2."""
|
||||
m = re.search(r'\d+', s)
|
||||
return int(m.group()) if m else 0
|
||||
@@ -1,32 +0,0 @@
|
||||
import os
|
||||
import requests
|
||||
import inspect
|
||||
|
||||
def get_module_path():
|
||||
return os.path.dirname(inspect.getfile(inspect.currentframe()))
|
||||
|
||||
GITHUB_API_URL = "https://api.github.com/repos/ufal/SimulStreaming/contents/simul_whisper/whisper"
|
||||
RAW_BASE_URL = "https://raw.githubusercontent.com/ufal/SimulStreaming/main/simul_whisper/whisper"
|
||||
TARGET_DIR = os.path.join(get_module_path(), "simul_whisper", "whisper")
|
||||
|
||||
def download_files_from_github(api_url, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
response = requests.get(api_url)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
for item in items:
|
||||
if item['type'] == 'file':
|
||||
download_url = item['download_url']
|
||||
file_name = item['name']
|
||||
file_response = requests.get(download_url)
|
||||
file_response.raise_for_status()
|
||||
with open(os.path.join(local_dir, file_name), 'wb') as f:
|
||||
f.write(file_response.content)
|
||||
elif item['type'] == 'dir':
|
||||
# Recursive call for subdirectories
|
||||
download_files_from_github(item['url'], os.path.join(local_dir, item['name']))
|
||||
|
||||
def download_simulstreaming_backend():
|
||||
print(f"Downloading files into {TARGET_DIR} ...")
|
||||
download_files_from_github(GITHUB_API_URL, TARGET_DIR)
|
||||
print("✅ Download of SimulStreaming backend files completed successfully.")
|
||||
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable
|
||||
import contextlib
|
||||
from typing import Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
ERROR_INSTALL_INSTRUCTIONS = """
|
||||
ERROR_INSTALL_INSTRUCTIONS = f"""
|
||||
{'='*50}
|
||||
FFmpeg is not installed or not found in your system's PATH.
|
||||
Please install FFmpeg to enable audio processing.
|
||||
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
|
||||
|
||||
Installation instructions:
|
||||
If you want to install FFmpeg:
|
||||
|
||||
# Ubuntu/Debian:
|
||||
sudo apt update && sudo apt install ffmpeg
|
||||
@@ -25,6 +26,7 @@ brew install ffmpeg
|
||||
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
|
||||
|
||||
After installation, please restart the application.
|
||||
{'='*50}
|
||||
"""
|
||||
|
||||
class FFmpegState(Enum):
|
||||
@@ -143,7 +145,7 @@ class FFmpegManager:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self.process.stdout.read(size),
|
||||
timeout=5.0
|
||||
timeout=20.0
|
||||
)
|
||||
return data
|
||||
except asyncio.TimeoutError:
|
||||
@@ -183,6 +185,8 @@ class FFmpegManager:
|
||||
async def _drain_stderr(self):
|
||||
try:
|
||||
while True:
|
||||
if not self.process or not self.process.stderr:
|
||||
break
|
||||
line = await self.process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
@@ -190,4 +194,4 @@ class FFmpegManager:
|
||||
except asyncio.CancelledError:
|
||||
logger.info("FFmpeg stderr drain task cancelled.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||
|
||||
284
whisperlivekit/local_agreement/backends.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
|
||||
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
self.lora_path = lora_path
|
||||
if lan == "auto":
|
||||
self.original_language = None
|
||||
else:
|
||||
self.original_language = lan
|
||||
self.model = self.load_model(model_size, cache_dir, model_dir)
|
||||
|
||||
def load_model(self, model_size, cache_dir, model_dir):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
def use_vad(self):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
|
||||
class WhisperASR(ASRBase):
|
||||
"""Uses WhisperLiveKit's built-in Whisper implementation."""
|
||||
sep = " "
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from whisperlivekit.whisper import load_model as load_whisper_model
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
if resolved_path.is_dir():
|
||||
model_info = detect_model_format(resolved_path)
|
||||
if not model_info.has_pytorch:
|
||||
raise FileNotFoundError(
|
||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||
)
|
||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
|
||||
|
||||
if model_size is None:
|
||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||
|
||||
return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
options = dict(self.transcribe_kargs)
|
||||
options.pop("vad", None)
|
||||
options.pop("vad_filter", None)
|
||||
language = self.original_language if self.original_language else None
|
||||
|
||||
result = whisper_transcribe(
|
||||
self.model,
|
||||
audio,
|
||||
language=language,
|
||||
initial_prompt=init_prompt,
|
||||
condition_on_previous_text=True,
|
||||
word_timestamps=True,
|
||||
**options,
|
||||
)
|
||||
return result
|
||||
|
||||
def ts_words(self, r) -> List[ASRToken]:
|
||||
"""
|
||||
Converts the Whisper result to a list of ASRToken objects.
|
||||
"""
|
||||
tokens = []
|
||||
for segment in r["segments"]:
|
||||
for word in segment["words"]:
|
||||
token = ASRToken(
|
||||
word["start"],
|
||||
word["end"],
|
||||
word["word"],
|
||||
probability=word.get("probability"),
|
||||
)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, res) -> List[float]:
|
||||
return [segment["end"] for segment in res["segments"]]
|
||||
|
||||
def use_vad(self):
|
||||
logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
|
||||
|
||||
class FasterWhisperASR(ASRBase):
|
||||
"""Uses faster-whisper as the backend."""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
|
||||
f"model_size and cache_dir parameters are not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = model_size
|
||||
else:
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
device = "auto" # Allow CTranslate2 to decide available device
|
||||
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
|
||||
|
||||
|
||||
model = WhisperModel(
|
||||
model_size_or_path,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
download_root=cache_dir,
|
||||
)
|
||||
return model
|
||||
|
||||
def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
|
||||
segments, info = self.model.transcribe(
|
||||
audio,
|
||||
language=self.original_language,
|
||||
initial_prompt=init_prompt,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
condition_on_previous_text=True,
|
||||
**self.transcribe_kargs,
|
||||
)
|
||||
return list(segments)
|
||||
|
||||
def ts_words(self, segments) -> List[ASRToken]:
|
||||
tokens = []
|
||||
for segment in segments:
|
||||
if segment.no_speech_prob > 0.9:
|
||||
continue
|
||||
for word in segment.words:
|
||||
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, segments) -> List[float]:
|
||||
return [segment.end for segment in segments]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
class MLXWhisper(ASRBase):
|
||||
"""
|
||||
Uses MLX Whisper optimized for Apple Silicon.
|
||||
"""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
import mlx.core as mx
|
||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = self.translate_model_name(model_size)
|
||||
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
else:
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
|
||||
self.model_size_or_path = model_size_or_path
|
||||
dtype = mx.float16
|
||||
ModelHolder.get_model(model_size_or_path, dtype)
|
||||
return transcribe
|
||||
|
||||
def translate_model_name(self, model_name):
|
||||
from whisperlivekit.model_mapping import MLX_MODEL_MAPPING
|
||||
mlx_model_path = MLX_MODEL_MAPPING.get(model_name)
|
||||
if mlx_model_path:
|
||||
return mlx_model_path
|
||||
else:
|
||||
raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
if self.transcribe_kargs:
|
||||
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
|
||||
segments = self.model(
|
||||
audio,
|
||||
language=self.original_language,
|
||||
initial_prompt=init_prompt,
|
||||
word_timestamps=True,
|
||||
condition_on_previous_text=True,
|
||||
path_or_hf_repo=self.model_size_or_path,
|
||||
)
|
||||
return segments.get("segments", [])
|
||||
|
||||
def ts_words(self, segments) -> List[ASRToken]:
|
||||
tokens = []
|
||||
for segment in segments:
|
||||
if segment.get("no_speech_prob", 0) > 0.9:
|
||||
continue
|
||||
for word in segment.get("words", []):
|
||||
token = ASRToken(word["start"], word["end"], word["word"])
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, res) -> List[float]:
|
||||
return [s["end"] for s in res]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
|
||||
class OpenaiApiASR(ASRBase):
|
||||
"""Uses OpenAI's Whisper API for transcription."""
|
||||
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
self.modelname = "whisper-1"
|
||||
self.original_language = None if lan == "auto" else lan
|
||||
self.response_format = "verbose_json"
|
||||
self.temperature = temperature
|
||||
self.load_model()
|
||||
self.use_vad_opt = False
|
||||
self.direct_english_translation = False
|
||||
self.task = "transcribe"
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI()
|
||||
self.transcribed_seconds = 0
|
||||
|
||||
def ts_words(self, segments) -> List[ASRToken]:
|
||||
"""
|
||||
Converts OpenAI API response words into ASRToken objects while
|
||||
optionally skipping words that fall into no-speech segments.
|
||||
"""
|
||||
no_speech_segments = []
|
||||
if self.use_vad_opt:
|
||||
for segment in segments.segments:
|
||||
if segment.no_speech_prob > 0.8:
|
||||
no_speech_segments.append((segment.start, segment.end))
|
||||
tokens = []
|
||||
for word in segments.words:
|
||||
start = word.start
|
||||
end = word.end
|
||||
if any(s[0] <= start <= s[1] for s in no_speech_segments):
|
||||
continue
|
||||
tokens.append(ASRToken(start, end, word.word))
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, res) -> List[float]:
|
||||
return [s.end for s in res.words]
|
||||
|
||||
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
|
||||
buffer = io.BytesIO()
|
||||
buffer.name = "temp.wav"
|
||||
sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
|
||||
buffer.seek(0)
|
||||
self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
|
||||
params = {
|
||||
"model": self.modelname,
|
||||
"file": buffer,
|
||||
"response_format": self.response_format,
|
||||
"temperature": self.temperature,
|
||||
"timestamp_granularities": ["word", "segment"],
|
||||
}
|
||||
if not self.direct_english_translation and self.original_language:
|
||||
params["language"] = self.original_language
|
||||
if prompt:
|
||||
params["prompt"] = prompt
|
||||
task = self.transcribe_kargs.get("task", self.task)
|
||||
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
|
||||
transcript = proc.create(**params)
|
||||
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
||||
return transcript
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad_opt = True
|
||||
@@ -1,23 +1,13 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import sys
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# simulStreaming imports - we check if the files are here
|
||||
try:
|
||||
import torch
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
SIMULSTREAMING_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("SimulStreaming dependencies not available for online processor.")
|
||||
SIMULSTREAMING_AVAILABLE = False
|
||||
OnlineProcessorInterface = None
|
||||
torch = None
|
||||
|
||||
|
||||
class HypothesisBuffer:
|
||||
"""
|
||||
Buffer to store and process ASR hypothesis tokens.
|
||||
@@ -118,9 +108,6 @@ class OnlineASRProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
tokenize_method: Optional[callable] = None,
|
||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
||||
confidence_validation = False,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
"""
|
||||
@@ -131,12 +118,14 @@ class OnlineASRProcessor:
|
||||
buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
|
||||
"""
|
||||
self.asr = asr
|
||||
self.tokenize = tokenize_method
|
||||
self.tokenize = asr.tokenizer
|
||||
self.logfile = logfile
|
||||
self.confidence_validation = confidence_validation
|
||||
self.confidence_validation = asr.confidence_validation
|
||||
self.global_time_offset = 0.0
|
||||
self.init()
|
||||
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
self.buffer_trimming_way = asr.buffer_trimming
|
||||
self.buffer_trimming_sec = asr.buffer_trimming_sec
|
||||
|
||||
if self.buffer_trimming_way not in ["sentence", "segment"]:
|
||||
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
|
||||
@@ -147,6 +136,11 @@ class OnlineASRProcessor:
|
||||
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
|
||||
)
|
||||
|
||||
def new_speaker(self, change_speaker):
|
||||
"""Handle speaker change event."""
|
||||
self.process_iter()
|
||||
self.init(offset=change_speaker.start)
|
||||
|
||||
def init(self, offset: Optional[float] = None):
|
||||
"""Initialize or reset the processing buffers."""
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
@@ -164,6 +158,32 @@ class OnlineASRProcessor:
|
||||
"""Append an audio chunk (a numpy array) to the current audio buffer."""
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
def start_silence(self):
|
||||
if self.audio_buffer.size == 0:
|
||||
return [], self.get_audio_buffer_end_time()
|
||||
return self.process_iter()
|
||||
|
||||
def end_silence(self, silence_duration: Optional[float], offset: float):
|
||||
if not silence_duration or silence_duration <= 0:
|
||||
return
|
||||
|
||||
long_silence = silence_duration >= 5
|
||||
if not long_silence:
|
||||
gap_samples = int(self.SAMPLING_RATE * silence_duration)
|
||||
if gap_samples > 0:
|
||||
gap_silence = np.zeros(gap_samples, dtype=np.float32)
|
||||
self.insert_audio_chunk(gap_silence)
|
||||
else:
|
||||
self.init(offset=silence_duration + offset)
|
||||
|
||||
self.global_time_offset += silence_duration
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
"""
|
||||
Backwards compatibility shim for legacy callers that still use insert_silence.
|
||||
"""
|
||||
self.end_silence(silence_duration, offset)
|
||||
|
||||
def prompt(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns a tuple: (prompt, context), where:
|
||||
@@ -242,6 +262,9 @@ class OnlineASRProcessor:
|
||||
logger.debug(
|
||||
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
||||
)
|
||||
if self.global_time_offset:
|
||||
for token in committed_tokens:
|
||||
token = token.with_offset(self.global_time_offset)
|
||||
return committed_tokens, current_audio_processed_upto
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
@@ -395,338 +418,11 @@ class OnlineASRProcessor:
|
||||
) -> Transcript:
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
# probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
|
||||
|
||||
class VACOnlineASRProcessor:
|
||||
"""
|
||||
Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
|
||||
|
||||
It receives small chunks of audio, applies VAD (e.g. with Silero),
|
||||
and when the system detects a pause in speech (or end of an utterance)
|
||||
it finalizes the utterance immediately.
|
||||
"""
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(self, online_chunk_size: float, *args, **kwargs):
|
||||
self.online_chunk_size = online_chunk_size
|
||||
self.online = OnlineASRProcessor(*args, **kwargs)
|
||||
self.asr = self.online.asr
|
||||
|
||||
# Load a VAD model (e.g. Silero VAD)
|
||||
import torch
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
from .silero_vad_iterator import FixedVADIterator
|
||||
|
||||
self.vac = FixedVADIterator(model)
|
||||
self.logfile = self.online.logfile
|
||||
self.last_input_audio_stream_end_time: float = 0.0
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
self.online.init()
|
||||
self.vac.reset_states()
|
||||
self.current_online_chunk_buffer_size = 0
|
||||
self.last_input_audio_stream_end_time = self.online.buffer_time_offset
|
||||
self.is_currently_final = False
|
||||
self.status: Optional[str] = None # "voice" or "nonvoice"
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.buffer_offset = 0 # in frames
|
||||
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
"""Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
|
||||
return self.online.get_audio_buffer_end_time()
|
||||
|
||||
def clear_buffer(self):
|
||||
self.buffer_offset += len(self.audio_buffer)
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||
"""
|
||||
Process an incoming small audio chunk:
|
||||
- run VAD on the chunk,
|
||||
- decide whether to send the audio to the online ASR processor immediately,
|
||||
- and/or to mark the current utterance as finished.
|
||||
"""
|
||||
self.last_input_audio_stream_end_time = audio_stream_end_time
|
||||
res = self.vac(audio)
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
if res is not None:
|
||||
# VAD returned a result; adjust the frame number
|
||||
frame = list(res.values())[0] - self.buffer_offset
|
||||
if "start" in res and "end" not in res:
|
||||
self.status = "voice"
|
||||
send_audio = self.audio_buffer[frame:]
|
||||
self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
|
||||
self.online.insert_audio_chunk(send_audio)
|
||||
self.current_online_chunk_buffer_size += len(send_audio)
|
||||
self.clear_buffer()
|
||||
elif "end" in res and "start" not in res:
|
||||
self.status = "nonvoice"
|
||||
send_audio = self.audio_buffer[:frame]
|
||||
self.online.insert_audio_chunk(send_audio)
|
||||
self.current_online_chunk_buffer_size += len(send_audio)
|
||||
self.is_currently_final = True
|
||||
self.clear_buffer()
|
||||
else:
|
||||
beg = res["start"] - self.buffer_offset
|
||||
end = res["end"] - self.buffer_offset
|
||||
self.status = "nonvoice"
|
||||
send_audio = self.audio_buffer[beg:end]
|
||||
self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
|
||||
self.online.insert_audio_chunk(send_audio)
|
||||
self.current_online_chunk_buffer_size += len(send_audio)
|
||||
self.is_currently_final = True
|
||||
self.clear_buffer()
|
||||
else:
|
||||
if self.status == "voice":
|
||||
self.online.insert_audio_chunk(self.audio_buffer)
|
||||
self.current_online_chunk_buffer_size += len(self.audio_buffer)
|
||||
self.clear_buffer()
|
||||
else:
|
||||
# Keep 1 second worth of audio in case VAD later detects voice,
|
||||
# but trim to avoid unbounded memory usage.
|
||||
self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
|
||||
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
|
||||
|
||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Depending on the VAD status and the amount of accumulated audio,
|
||||
process the current audio chunk.
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
if self.is_currently_final:
|
||||
return self.finish()
|
||||
elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
|
||||
self.current_online_chunk_buffer_size = 0
|
||||
return self.online.process_iter()
|
||||
else:
|
||||
logger.debug("No online update, only VAD")
|
||||
return [], self.last_input_audio_stream_end_time
|
||||
|
||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Finish processing by flushing any remaining text.
|
||||
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
|
||||
"""
|
||||
result_tokens, processed_upto = self.online.finish()
|
||||
self.current_online_chunk_buffer_size = 0
|
||||
self.is_currently_final = False
|
||||
return result_tokens, processed_upto
|
||||
|
||||
def get_buffer(self):
|
||||
"""
|
||||
Get the unvalidated buffer in string format.
|
||||
"""
|
||||
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
|
||||
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
tokenize_method: Optional[callable] = None,
|
||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
||||
confidence_validation = False,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
if not SIMULSTREAMING_AVAILABLE:
|
||||
raise ImportError("SimulStreaming dependencies are not available.")
|
||||
|
||||
self.asr = asr
|
||||
self.tokenize = tokenize_method
|
||||
self.logfile = logfile
|
||||
self.confidence_validation = confidence_validation
|
||||
self.init()
|
||||
|
||||
# buffer does not work yet
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
|
||||
def init(self, offset: Optional[float] = None):
|
||||
"""Initialize or reset the processing state."""
|
||||
self.audio_chunks = []
|
||||
self.offset = offset if offset is not None else 0.0
|
||||
self.is_last = False
|
||||
self.beg = self.offset
|
||||
self.end = self.offset
|
||||
self.cumulative_audio_duration = 0.0
|
||||
self.last_audio_stream_end_time = self.offset
|
||||
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.buffer_content = ""
|
||||
self.processed_audio_duration = 0.0
|
||||
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
"""Returns the absolute end time of the current audio buffer."""
|
||||
return self.end
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||
if torch is None:
|
||||
raise ImportError("PyTorch is required for SimulStreaming but not available")
|
||||
|
||||
# Convert numpy array to torch tensor
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.audio_chunks.append(audio_tensor)
|
||||
|
||||
# Update timing
|
||||
chunk_duration = len(audio) / self.SAMPLING_RATE
|
||||
self.cumulative_audio_duration += chunk_duration
|
||||
|
||||
if audio_stream_end_time is not None:
|
||||
self.last_audio_stream_end_time = audio_stream_end_time
|
||||
self.end = audio_stream_end_time
|
||||
else:
|
||||
self.end = self.offset + self.cumulative_audio_duration
|
||||
|
||||
def prompt(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns a tuple: (prompt, context).
|
||||
SimulStreaming handles prompting internally, so we return empty strings.
|
||||
"""
|
||||
return "", ""
|
||||
|
||||
def get_buffer(self):
|
||||
"""
|
||||
Get the unvalidated buffer content.
|
||||
"""
|
||||
buffer_end = self.end if hasattr(self, 'end') else None
|
||||
return Transcript(
|
||||
start=None,
|
||||
end=buffer_end,
|
||||
text=self.buffer_content,
|
||||
probability=None
|
||||
)
|
||||
|
||||
def timestamped_text(self, tokens, generation):
|
||||
# From the simulstreaming repo. self.model to self.asr.model
|
||||
pr = generation["progress"]
|
||||
if "result" not in generation:
|
||||
split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
|
||||
else:
|
||||
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
|
||||
|
||||
frames = [p["most_attended_frames"][0] for p in pr]
|
||||
tokens = tokens.copy()
|
||||
ret = []
|
||||
for sw,st in zip(split_words,split_tokens):
|
||||
b = None
|
||||
for stt in st:
|
||||
t,f = tokens.pop(0), frames.pop(0)
|
||||
if t != stt:
|
||||
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
|
||||
if b is None:
|
||||
b = f
|
||||
e = f
|
||||
out = (b*0.02, e*0.02, sw)
|
||||
ret.append(out)
|
||||
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
|
||||
return ret
|
||||
|
||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Process accumulated audio chunks using SimulStreaming.
|
||||
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
if not self.audio_chunks:
|
||||
return [], self.end
|
||||
|
||||
try:
|
||||
# concatenate all audio chunks
|
||||
if len(self.audio_chunks) == 1:
|
||||
audio = self.audio_chunks[0]
|
||||
else:
|
||||
audio = torch.cat(self.audio_chunks, dim=0)
|
||||
|
||||
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
|
||||
self.processed_audio_duration += audio_duration
|
||||
|
||||
self.audio_chunks = []
|
||||
|
||||
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
|
||||
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
|
||||
|
||||
self.asr.model.insert_audio(audio)
|
||||
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
|
||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||
text = self.asr.model.tokenizer.decode(tokens)
|
||||
|
||||
new_tokens = []
|
||||
for ts_word in ts_words:
|
||||
|
||||
start, end, word = ts_word
|
||||
token = ASRToken(
|
||||
start=start,
|
||||
end=end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
)
|
||||
new_tokens.append(token)
|
||||
self.committed.extend(new_tokens)
|
||||
|
||||
return new_tokens, self.end
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SimulStreaming processing error: {e}")
|
||||
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
|
||||
return [], self.end
|
||||
|
||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||
logger.debug("SimulStreaming finish() called")
|
||||
self.is_last = True
|
||||
final_tokens, final_time = self.process_iter()
|
||||
self.is_last = False
|
||||
return final_tokens, final_time
|
||||
|
||||
def concatenate_tokens(
|
||||
self,
|
||||
tokens: List[ASRToken],
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> Transcript:
|
||||
"""Concatenate tokens into a Transcript object."""
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
|
||||
def chunk_at(self, time: float):
|
||||
"""
|
||||
useless but kept for compatibility
|
||||
"""
|
||||
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
|
||||
pass
|
||||
|
||||
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
|
||||
"""
|
||||
Create simple sentences.
|
||||
"""
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
full_text = " ".join(token.text for token in tokens)
|
||||
sentence = Sentence(
|
||||
start=tokens[0].start,
|
||||
end=tokens[-1].end,
|
||||
text=full_text
|
||||
)
|
||||
return [sentence]
|
||||
return Transcript(start, end, text)
|
||||
202
whisperlivekit/local_agreement/whisper_online.py
Normal file
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python3
|
||||
import logging
|
||||
import platform
|
||||
import time
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
|
||||
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
|
||||
","
|
||||
)
|
||||
|
||||
|
||||
def create_tokenizer(lan):
|
||||
"""returns an object that has split function that works like the one of MosesTokenizer"""
|
||||
|
||||
assert (
|
||||
lan in WHISPER_LANG_CODES
|
||||
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
|
||||
|
||||
if lan == "uk":
|
||||
import tokenize_uk
|
||||
|
||||
class UkrainianTokenizer:
|
||||
def split(self, text):
|
||||
return tokenize_uk.tokenize_sents(text)
|
||||
|
||||
return UkrainianTokenizer()
|
||||
|
||||
# supported by fast-mosestokenizer
|
||||
if (
|
||||
lan
|
||||
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
||||
):
|
||||
from mosestokenizer import MosesSentenceSplitter
|
||||
|
||||
return MosesSentenceSplitter(lan)
|
||||
|
||||
# the following languages are in Whisper, but not in wtpsplit:
|
||||
if (
|
||||
lan
|
||||
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
|
||||
):
|
||||
logger.debug(
|
||||
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
|
||||
)
|
||||
lan = None
|
||||
|
||||
from wtpsplit import WtP
|
||||
|
||||
# downloads the model from huggingface on the first use
|
||||
wtp = WtP("wtp-canine-s-12l-no-adapters")
|
||||
|
||||
class WtPtok:
|
||||
def split(self, sent):
|
||||
return wtp.split(sent, lang_code=lan)
|
||||
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def backend_factory(
|
||||
backend,
|
||||
lan,
|
||||
model_size,
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
model_path,
|
||||
lora_path,
|
||||
direct_english_translation,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
confidence_validation,
|
||||
warmup_file=None,
|
||||
min_chunk_size=None,
|
||||
):
|
||||
backend_choice = backend
|
||||
custom_reference = model_path or model_dir
|
||||
resolved_root = None
|
||||
has_mlx_weights = False
|
||||
has_fw_weights = False
|
||||
has_pytorch = False
|
||||
|
||||
if custom_reference:
|
||||
resolved_root = resolve_model_path(custom_reference)
|
||||
if resolved_root.is_dir():
|
||||
model_info = detect_model_format(resolved_root)
|
||||
has_mlx_weights = model_info.compatible_whisper_mlx
|
||||
has_fw_weights = model_info.compatible_faster_whisper
|
||||
has_pytorch = model_info.has_pytorch
|
||||
else:
|
||||
# Single file provided
|
||||
has_pytorch = True
|
||||
|
||||
if backend_choice == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
backend_choice = _normalize_backend_choice(
|
||||
backend_choice,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
)
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
elif backend_choice == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
else:
|
||||
asr_cls = WhisperASR
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
if custom_reference and not has_pytorch:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
||||
)
|
||||
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
|
||||
asr = asr_cls(
|
||||
model_size=model_size,
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_override,
|
||||
lora_path=lora_path if backend_choice == "whisper" else None,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
if direct_english_translation:
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
asr.transcribe_kargs["task"] = "translate"
|
||||
else:
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
asr.buffer_trimming_sec = buffer_trimming_sec
|
||||
asr.backend_choice = backend_choice
|
||||
return asr
|
||||
|
||||
|
||||
def _normalize_backend_choice(
|
||||
preferred_backend,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
):
|
||||
backend_choice = preferred_backend
|
||||
|
||||
if backend_choice == "auto":
|
||||
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
|
||||
return "mlx-whisper"
|
||||
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
if backend_choice == "mlx-whisper":
|
||||
if not mlx_backend_available():
|
||||
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
|
||||
if resolved_root is not None and not has_mlx_weights:
|
||||
raise FileNotFoundError(
|
||||
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
|
||||
)
|
||||
if platform.system() != "Darwin":
|
||||
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
if not faster_backend_available():
|
||||
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
|
||||
if resolved_root is not None and not has_fw_weights:
|
||||
raise FileNotFoundError(
|
||||
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
|
||||
)
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "whisper":
|
||||
return backend_choice
|
||||
|
||||
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")
|
||||
156
whisperlivekit/metrics.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Lightweight ASR evaluation metrics — no external dependencies.
|
||||
|
||||
Provides WER (Word Error Rate) computation via word-level Levenshtein distance,
|
||||
text normalization, and word-level timestamp accuracy metrics with greedy alignment.
|
||||
"""
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""Normalize text for WER comparison: lowercase, strip punctuation, collapse whitespace."""
|
||||
text = text.lower()
|
||||
# Normalize unicode (e.g., accented chars to composed form)
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
# Remove punctuation (keep letters, numbers, spaces, hyphens within words)
|
||||
text = re.sub(r"[^\w\s\-']", " ", text)
|
||||
# Collapse whitespace
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def compute_wer(reference: str, hypothesis: str) -> Dict:
|
||||
"""Compute Word Error Rate using word-level Levenshtein edit distance.
|
||||
|
||||
Args:
|
||||
reference: Ground truth transcription.
|
||||
hypothesis: Predicted transcription.
|
||||
|
||||
Returns:
|
||||
Dict with keys: wer, substitutions, insertions, deletions, ref_words, hyp_words.
|
||||
WER can exceed 1.0 if there are more errors than reference words.
|
||||
"""
|
||||
ref_words = normalize_text(reference).split()
|
||||
hyp_words = normalize_text(hypothesis).split()
|
||||
|
||||
n = len(ref_words)
|
||||
m = len(hyp_words)
|
||||
|
||||
if n == 0:
|
||||
return {
|
||||
"wer": 0.0 if m == 0 else float(m),
|
||||
"substitutions": 0,
|
||||
"insertions": m,
|
||||
"deletions": 0,
|
||||
"ref_words": 0,
|
||||
"hyp_words": m,
|
||||
}
|
||||
|
||||
# DP table: dp[i][j] = (edit_distance, substitutions, insertions, deletions)
|
||||
dp = [[(0, 0, 0, 0) for _ in range(m + 1)] for _ in range(n + 1)]
|
||||
|
||||
for i in range(1, n + 1):
|
||||
dp[i][0] = (i, 0, 0, i)
|
||||
for j in range(1, m + 1):
|
||||
dp[0][j] = (j, 0, j, 0)
|
||||
|
||||
for i in range(1, n + 1):
|
||||
for j in range(1, m + 1):
|
||||
if ref_words[i - 1] == hyp_words[j - 1]:
|
||||
dp[i][j] = dp[i - 1][j - 1]
|
||||
else:
|
||||
sub = dp[i - 1][j - 1]
|
||||
ins = dp[i][j - 1]
|
||||
dele = dp[i - 1][j]
|
||||
|
||||
sub_cost = (sub[0] + 1, sub[1] + 1, sub[2], sub[3])
|
||||
ins_cost = (ins[0] + 1, ins[1], ins[2] + 1, ins[3])
|
||||
del_cost = (dele[0] + 1, dele[1], dele[2], dele[3] + 1)
|
||||
|
||||
dp[i][j] = min(sub_cost, del_cost, ins_cost, key=lambda x: x[0])
|
||||
|
||||
dist, subs, ins, dels = dp[n][m]
|
||||
return {
|
||||
"wer": dist / n,
|
||||
"substitutions": subs,
|
||||
"insertions": ins,
|
||||
"deletions": dels,
|
||||
"ref_words": n,
|
||||
"hyp_words": m,
|
||||
}
|
||||
|
||||
|
||||
def compute_timestamp_accuracy(
|
||||
predicted: List[Dict],
|
||||
reference: List[Dict],
|
||||
) -> Dict:
|
||||
"""Compute timestamp accuracy by aligning predicted words to reference words.
|
||||
|
||||
Uses greedy left-to-right alignment on normalized text. For each matched pair,
|
||||
computes the start-time delta (predicted - reference).
|
||||
|
||||
Args:
|
||||
predicted: List of dicts with keys: word, start, end.
|
||||
reference: List of dicts with keys: word, start, end.
|
||||
|
||||
Returns:
|
||||
Dict with keys: mae_start, max_delta_start, median_delta_start,
|
||||
n_matched, n_ref, n_pred. Returns None values if no matches found.
|
||||
"""
|
||||
if not predicted or not reference:
|
||||
return {
|
||||
"mae_start": None,
|
||||
"max_delta_start": None,
|
||||
"median_delta_start": None,
|
||||
"n_matched": 0,
|
||||
"n_ref": len(reference),
|
||||
"n_pred": len(predicted),
|
||||
}
|
||||
|
||||
# Normalize words for matching
|
||||
pred_norm = [normalize_text(p["word"]) for p in predicted]
|
||||
ref_norm = [normalize_text(r["word"]) for r in reference]
|
||||
|
||||
# Greedy left-to-right alignment
|
||||
deltas_start = []
|
||||
ref_idx = 0
|
||||
for p_idx, p_word in enumerate(pred_norm):
|
||||
if not p_word:
|
||||
continue
|
||||
# Scan forward in reference to find a match (allow small skips)
|
||||
search_limit = min(ref_idx + 3, len(ref_norm))
|
||||
for r_idx in range(ref_idx, search_limit):
|
||||
if ref_norm[r_idx] == p_word:
|
||||
delta = predicted[p_idx]["start"] - reference[r_idx]["start"]
|
||||
deltas_start.append(delta)
|
||||
ref_idx = r_idx + 1
|
||||
break
|
||||
|
||||
if not deltas_start:
|
||||
return {
|
||||
"mae_start": None,
|
||||
"max_delta_start": None,
|
||||
"median_delta_start": None,
|
||||
"n_matched": 0,
|
||||
"n_ref": len(reference),
|
||||
"n_pred": len(predicted),
|
||||
}
|
||||
|
||||
abs_deltas = [abs(d) for d in deltas_start]
|
||||
sorted_abs = sorted(abs_deltas)
|
||||
n = len(sorted_abs)
|
||||
if n % 2 == 1:
|
||||
median = sorted_abs[n // 2]
|
||||
else:
|
||||
median = (sorted_abs[n // 2 - 1] + sorted_abs[n // 2]) / 2
|
||||
|
||||
return {
|
||||
"mae_start": sum(abs_deltas) / len(abs_deltas),
|
||||
"max_delta_start": max(abs_deltas),
|
||||
"median_delta_start": median,
|
||||
"n_matched": len(deltas_start),
|
||||
"n_ref": len(reference),
|
||||
"n_pred": len(predicted),
|
||||
}
|
||||
84
whisperlivekit/metrics_collector.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Lightweight runtime metrics for AudioProcessor sessions.
|
||||
|
||||
Zero external dependencies. Negligible overhead when not queried —
|
||||
just integer increments and list appends during normal operation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionMetrics:
|
||||
"""Per-session metrics collected by AudioProcessor."""
|
||||
|
||||
session_start: float = 0.0
|
||||
total_audio_duration_s: float = 0.0
|
||||
total_processing_time_s: float = 0.0
|
||||
|
||||
# Chunk / call counters
|
||||
n_chunks_received: int = 0
|
||||
n_transcription_calls: int = 0
|
||||
n_tokens_produced: int = 0
|
||||
n_responses_sent: int = 0
|
||||
|
||||
# Per-call ASR latency (seconds)
|
||||
transcription_durations: List[float] = field(default_factory=list)
|
||||
|
||||
# Silence
|
||||
n_silence_events: int = 0
|
||||
total_silence_duration_s: float = 0.0
|
||||
|
||||
# --- Computed properties ---
|
||||
|
||||
@property
|
||||
def rtf(self) -> float:
|
||||
"""Real-time factor: processing_time / audio_duration."""
|
||||
if self.total_audio_duration_s <= 0:
|
||||
return 0.0
|
||||
return self.total_processing_time_s / self.total_audio_duration_s
|
||||
|
||||
@property
|
||||
def avg_latency_ms(self) -> float:
|
||||
"""Average per-call ASR latency in milliseconds."""
|
||||
if not self.transcription_durations:
|
||||
return 0.0
|
||||
return (sum(self.transcription_durations) / len(self.transcription_durations)) * 1000
|
||||
|
||||
@property
|
||||
def p95_latency_ms(self) -> float:
|
||||
"""95th percentile per-call ASR latency in milliseconds."""
|
||||
if not self.transcription_durations:
|
||||
return 0.0
|
||||
sorted_d = sorted(self.transcription_durations)
|
||||
idx = int(len(sorted_d) * 0.95)
|
||||
idx = min(idx, len(sorted_d) - 1)
|
||||
return sorted_d[idx] * 1000
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Serialize to a plain dict (JSON-safe)."""
|
||||
return {
|
||||
"session_start": self.session_start,
|
||||
"total_audio_duration_s": round(self.total_audio_duration_s, 3),
|
||||
"total_processing_time_s": round(self.total_processing_time_s, 3),
|
||||
"rtf": round(self.rtf, 3),
|
||||
"n_chunks_received": self.n_chunks_received,
|
||||
"n_transcription_calls": self.n_transcription_calls,
|
||||
"n_tokens_produced": self.n_tokens_produced,
|
||||
"n_responses_sent": self.n_responses_sent,
|
||||
"avg_latency_ms": round(self.avg_latency_ms, 2),
|
||||
"p95_latency_ms": round(self.p95_latency_ms, 2),
|
||||
"n_silence_events": self.n_silence_events,
|
||||
"total_silence_duration_s": round(self.total_silence_duration_s, 3),
|
||||
}
|
||||
|
||||
def log_summary(self) -> None:
|
||||
"""Emit a structured log line summarising the session."""
|
||||
self.total_processing_time_s = sum(self.transcription_durations)
|
||||
d = self.to_dict()
|
||||
d["session_elapsed_s"] = round(time.time() - self.session_start, 3) if self.session_start else 0
|
||||
logger.info(f"SESSION_METRICS {d}")
|
||||
17
whisperlivekit/model_mapping.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Shared MLX model name mapping used by both SimulStreaming and LocalAgreement backends."""
|
||||
|
||||
MLX_MODEL_MAPPING = {
|
||||
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
||||
"tiny": "mlx-community/whisper-tiny-mlx",
|
||||
"base.en": "mlx-community/whisper-base.en-mlx",
|
||||
"base": "mlx-community/whisper-base-mlx",
|
||||
"small.en": "mlx-community/whisper-small.en-mlx",
|
||||
"small": "mlx-community/whisper-small-mlx",
|
||||
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
||||
"medium": "mlx-community/whisper-medium-mlx",
|
||||
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
||||
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
||||
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
||||
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
||||
"large": "mlx-community/whisper-large-mlx",
|
||||
}
|
||||
215
whisperlivekit/model_paths.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Information about detected model format and files in a directory."""
|
||||
path: Optional[Path] = None
|
||||
pytorch_files: List[Path] = field(default_factory=list)
|
||||
compatible_whisper_mlx: bool = False
|
||||
compatible_faster_whisper: bool = False
|
||||
|
||||
@property
|
||||
def has_pytorch(self) -> bool:
|
||||
return len(self.pytorch_files) > 0
|
||||
|
||||
@property
|
||||
def is_sharded(self) -> bool:
|
||||
return len(self.pytorch_files) > 1
|
||||
|
||||
@property
|
||||
def primary_pytorch_file(self) -> Optional[Path]:
|
||||
"""Return the primary PyTorch file (or first shard for sharded models)."""
|
||||
if not self.pytorch_files:
|
||||
return None
|
||||
return self.pytorch_files[0]
|
||||
|
||||
|
||||
#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin
|
||||
SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$")
|
||||
|
||||
FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"}
|
||||
MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"}
|
||||
CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"}
|
||||
|
||||
|
||||
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
||||
"""
|
||||
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
|
||||
|
||||
CTranslate2 models have specific companion files that distinguish them
|
||||
from PyTorch .bin files.
|
||||
"""
|
||||
n_indicators = 0
|
||||
for indicator in CT2_INDICATOR_FILES: #test 1
|
||||
if (directory / indicator).exists():
|
||||
n_indicators += 1
|
||||
|
||||
if n_indicators == 0:
|
||||
return False
|
||||
|
||||
config_path = directory / "config.json" #test 2
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
if config.get("model_type") == "whisper": #test 2
|
||||
return False
|
||||
except (json.JSONDecodeError, IOError):
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||
"""
|
||||
Collect all PyTorch checkpoint files from a directory.
|
||||
|
||||
Handles:
|
||||
- Single files: model.safetensors, pytorch_model.bin, *.pt
|
||||
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
|
||||
- Index-based sharded models (reads index file to find shards)
|
||||
|
||||
Returns files sorted appropriately (shards in order, or single file).
|
||||
"""
|
||||
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
|
||||
index_path = directory / index_name
|
||||
if index_path.exists():
|
||||
try:
|
||||
with open(index_path, "r", encoding="utf-8") as f:
|
||||
index_data = json.load(f)
|
||||
weight_map = index_data.get("weight_map", {})
|
||||
if weight_map:
|
||||
shard_names = sorted(set(weight_map.values()))
|
||||
shards = [directory / name for name in shard_names if (directory / name).exists()]
|
||||
if shards:
|
||||
return shards
|
||||
except (json.JSONDecodeError, IOError):
|
||||
pass
|
||||
|
||||
sharded_groups = {}
|
||||
single_files = {}
|
||||
|
||||
for file in directory.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
if filename.startswith("adapter_"):
|
||||
continue
|
||||
|
||||
match = SHARDED_PATTERN.match(filename)
|
||||
if match:
|
||||
base_name, shard_idx, total_shards, ext = match.groups()
|
||||
key = (base_name, ext, int(total_shards))
|
||||
if key not in sharded_groups:
|
||||
sharded_groups[key] = []
|
||||
sharded_groups[key].append((int(shard_idx), file))
|
||||
continue
|
||||
|
||||
if filename == "model.safetensors":
|
||||
single_files[0] = file # Highest priority
|
||||
elif filename == "pytorch_model.bin":
|
||||
single_files[1] = file
|
||||
elif suffix == ".pt":
|
||||
single_files[2] = file
|
||||
elif suffix == ".safetensors" and not filename.startswith("adapter"):
|
||||
single_files[3] = file
|
||||
|
||||
for (base_name, ext, total_shards), shards in sharded_groups.items():
|
||||
if len(shards) == total_shards:
|
||||
return [path for _, path in sorted(shards)]
|
||||
|
||||
for priority in sorted(single_files.keys()):
|
||||
return [single_files[priority]]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
|
||||
"""
|
||||
Detect the model format in a given path.
|
||||
|
||||
This function analyzes a file or directory to determine:
|
||||
- What PyTorch checkpoint files are available (including sharded models)
|
||||
- Whether the directory contains MLX Whisper weights
|
||||
- Whether the directory contains Faster-Whisper (CTranslate2) weights
|
||||
|
||||
Args:
|
||||
model_path: Path to a model file or directory
|
||||
|
||||
Returns:
|
||||
ModelInfo with detected format information
|
||||
"""
|
||||
path = Path(model_path)
|
||||
info = ModelInfo(path=path)
|
||||
|
||||
if path.is_file():
|
||||
suffix = path.suffix.lower()
|
||||
if suffix in {".pt", ".safetensors", ".bin"}:
|
||||
info.pytorch_files = [path]
|
||||
return info
|
||||
|
||||
if not path.is_dir():
|
||||
return info
|
||||
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name.lower()
|
||||
|
||||
if filename in MLX_WHISPER_MARKERS:
|
||||
info.compatible_whisper_mlx = True
|
||||
|
||||
if filename in FASTER_WHISPER_MARKERS:
|
||||
if _is_ct2_model_bin(path, filename):
|
||||
info.compatible_faster_whisper = True
|
||||
|
||||
info.pytorch_files = _collect_pytorch_files(path)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||
"""
|
||||
Inspect the provided path and determine which model formats are available.
|
||||
|
||||
This is a compatibility wrapper around detect_model_format().
|
||||
|
||||
Returns:
|
||||
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
|
||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||
compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
|
||||
"""
|
||||
info = detect_model_format(model_path)
|
||||
return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
Return a local path for the provided model reference.
|
||||
|
||||
If the path does not exist locally, it is treated as a Hugging Face repo id
|
||||
and downloaded via snapshot_download.
|
||||
"""
|
||||
path = Path(model_path).expanduser()
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc:
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
) from exc
|
||||
|
||||
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
|
||||
return downloaded_path
|
||||
@@ -1,6 +1,7 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument(
|
||||
@@ -20,7 +21,7 @@ def parse_args():
|
||||
help="""
|
||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
If False, no warmup is performed.
|
||||
If empty, no warmup is performed.
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -58,23 +59,38 @@ def parse_args():
|
||||
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
type=str,
|
||||
default="sortformer",
|
||||
choices=["sortformer", "diart"],
|
||||
help="The diarization backend to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-transcription",
|
||||
action="store_true",
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-punctuation-split",
|
||||
action="store_true",
|
||||
help="Disable the split parameter.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.5,
|
||||
default=0.1,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny",
|
||||
default="base",
|
||||
dest='model_size',
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
@@ -90,32 +106,55 @@ def parse_args():
|
||||
default=None,
|
||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="lora_path",
|
||||
help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). Only works with native Whisper backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lan",
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
dest='lan',
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
"--direct-english-translation",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use Whisper to directly translate to english.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-language",
|
||||
type=str,
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
default="",
|
||||
dest="target_language",
|
||||
help="Target language for translation. Not functional yet.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend-policy",
|
||||
type=str,
|
||||
default="simulstreaming",
|
||||
choices=["1", "2", "simulstreaming", "localagreement"],
|
||||
help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="faster-whisper",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral", "voxtral-mlx"],
|
||||
help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for HF Transformers Voxtral (CUDA/CPU/MPS). Use 'voxtral-mlx' for native MLX Voxtral on Apple Silicon.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac",
|
||||
"--no-vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
||||
help="Disable VAC = voice activity controller.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||
@@ -150,9 +189,30 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
|
||||
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
|
||||
parser.add_argument(
|
||||
"--pcm-input",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
||||
)
|
||||
# SimulStreaming-specific arguments
|
||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--disable-fast-encoder",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="disable_fast_encoder",
|
||||
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--custom-alignment-heads",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Use your own alignment heads, useful when `--model-dir` is used",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--frame-threshold",
|
||||
@@ -242,12 +302,28 @@ def parse_args():
|
||||
dest="model_path",
|
||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="transformers",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-size",
|
||||
type=str,
|
||||
default="600M",
|
||||
help="600M or 1.3B",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
args.vad = not args.no_vad
|
||||
args.vad = not args.no_vad
|
||||
args.vac = not args.no_vac
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
return args
|
||||
delattr(args, 'no_vac')
|
||||
|
||||
from whisperlivekit.config import WhisperLiveKitConfig
|
||||
return WhisperLiveKitConfig.from_namespace(args)
|
||||
|
||||
326
whisperlivekit/silero_vad_iterator.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
"""
|
||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||
"""
|
||||
|
||||
def is_onnx_available() -> bool:
|
||||
"""Check if onnxruntime is installed."""
|
||||
try:
|
||||
import onnxruntime
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||
"""Load a JIT model from file."""
|
||||
model = torch.jit.load(model_path, map_location=device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class OnnxSession():
|
||||
"""
|
||||
Shared ONNX session for Silero VAD model (stateless).
|
||||
"""
|
||||
|
||||
def __init__(self, path, force_onnx_cpu=False):
|
||||
import onnxruntime
|
||||
|
||||
opts = onnxruntime.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
||||
else:
|
||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||
|
||||
self.path = path
|
||||
if '16k' in path:
|
||||
warnings.warn('This model support only 16000 sampling rate!')
|
||||
self.sample_rates = [16000]
|
||||
else:
|
||||
self.sample_rates = [8000, 16000]
|
||||
|
||||
|
||||
class OnnxWrapper():
|
||||
"""
|
||||
ONNX Runtime wrapper for Silero VAD model with per-instance state.
|
||||
"""
|
||||
|
||||
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
|
||||
self._shared_session = session
|
||||
self.sample_rates = session.sample_rates
|
||||
self.reset_states()
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self._shared_session.session
|
||||
|
||||
def _validate_input(self, x, sr: int):
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 2:
|
||||
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
||||
|
||||
if sr != 16000 and (sr % 16000 == 0):
|
||||
step = sr // 16000
|
||||
x = x[:,::step]
|
||||
sr = 16000
|
||||
|
||||
if sr not in self.sample_rates:
|
||||
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
|
||||
if sr / x.shape[1] > 31.25:
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
return x, sr
|
||||
|
||||
def reset_states(self, batch_size=1):
|
||||
self._state = torch.zeros((2, batch_size, 128)).float()
|
||||
self._context = torch.zeros(0)
|
||||
self._last_sr = 0
|
||||
self._last_batch_size = 0
|
||||
|
||||
def __call__(self, x, sr: int):
|
||||
|
||||
x, sr = self._validate_input(x, sr)
|
||||
num_samples = 512 if sr == 16000 else 256
|
||||
|
||||
if x.shape[-1] != num_samples:
|
||||
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
||||
|
||||
batch_size = x.shape[0]
|
||||
context_size = 64 if sr == 16000 else 32
|
||||
|
||||
if not self._last_batch_size:
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_sr) and (self._last_sr != sr):
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_batch_size) and (self._last_batch_size != batch_size):
|
||||
self.reset_states(batch_size)
|
||||
|
||||
if not len(self._context):
|
||||
self._context = torch.zeros(batch_size, context_size)
|
||||
|
||||
x = torch.cat([self._context, x], dim=1)
|
||||
if sr in [8000, 16000]:
|
||||
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
|
||||
ort_outs = self.session.run(None, ort_inputs)
|
||||
out, state = ort_outs
|
||||
self._state = torch.from_numpy(state)
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling rate {sr}. Supported: {self.sample_rates} (with sample sizes 256 for 8000, 512 for 16000)")
|
||||
|
||||
self._context = x[..., -context_size:]
|
||||
self._last_sr = sr
|
||||
self._last_batch_size = batch_size
|
||||
|
||||
out = torch.from_numpy(out)
|
||||
return out
|
||||
|
||||
|
||||
def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
|
||||
"""Get the path to the ONNX model file."""
|
||||
available_ops = [15, 16]
|
||||
if opset_version not in available_ops:
|
||||
raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}')
|
||||
|
||||
if model_path is None:
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / 'silero_vad_models'
|
||||
|
||||
if opset_version == 16:
|
||||
model_name = 'silero_vad.onnx'
|
||||
else:
|
||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||
|
||||
model_path = data_dir / model_name
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found: {model_path}\n"
|
||||
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
|
||||
)
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
|
||||
"""
|
||||
Load a shared ONNX session for Silero VAD.
|
||||
"""
|
||||
path = _get_onnx_model_path(model_path, opset_version)
|
||||
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
|
||||
|
||||
|
||||
def load_jit_vad(model_path: str = None):
|
||||
"""
|
||||
Load Silero VAD model in JIT format.
|
||||
"""
|
||||
if model_path is None:
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / 'silero_vad_models'
|
||||
model_name = 'silero_vad.jit'
|
||||
|
||||
model_path = data_dir / model_name
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found: {model_path}\n"
|
||||
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
|
||||
)
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
|
||||
model = init_jit_model(str(model_path))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class VADIterator:
|
||||
"""
|
||||
Voice Activity Detection iterator for streaming audio.
|
||||
|
||||
This is the Silero VAD v6 implementation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 100,
|
||||
speech_pad_ms: int = 30
|
||||
):
|
||||
|
||||
"""
|
||||
Class for stream imitation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: preloaded .jit/.onnx silero VAD model
|
||||
|
||||
threshold: float (default - 0.5)
|
||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||
|
||||
sampling_rate: int (default - 16000)
|
||||
Currently silero VAD models support 8000 and 16000 sample rates
|
||||
|
||||
min_silence_duration_ms: int (default - 100 milliseconds)
|
||||
In the end of each speech chunk wait for min_silence_duration_ms before separating it
|
||||
|
||||
speech_pad_ms: int (default - 30 milliseconds)
|
||||
Final speech chunks are padded by speech_pad_ms each side
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
self.threshold = threshold
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
if sampling_rate not in [8000, 16000]:
|
||||
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||
|
||||
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
self.reset_states()
|
||||
|
||||
def reset_states(self):
|
||||
|
||||
self.model.reset_states()
|
||||
self.triggered = False
|
||||
self.temp_end = 0
|
||||
self.current_sample = 0
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
|
||||
"""
|
||||
x: torch.Tensor
|
||||
audio chunk (see examples in repo)
|
||||
|
||||
return_seconds: bool (default - False)
|
||||
whether return timestamps in seconds (default - samples)
|
||||
|
||||
time_resolution: int (default - 1)
|
||||
time resolution of speech coordinates when requested as seconds
|
||||
"""
|
||||
|
||||
if not torch.is_tensor(x):
|
||||
try:
|
||||
x = torch.Tensor(x)
|
||||
except (ValueError, TypeError, RuntimeError) as exc:
|
||||
raise TypeError("Audio cannot be cast to tensor. Cast it manually") from exc
|
||||
|
||||
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
||||
self.current_sample += window_size_samples
|
||||
|
||||
speech_prob = self.model(x, self.sampling_rate).item()
|
||||
|
||||
if (speech_prob >= self.threshold) and self.temp_end:
|
||||
self.temp_end = 0
|
||||
|
||||
if (speech_prob >= self.threshold) and not self.triggered:
|
||||
self.triggered = True
|
||||
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
||||
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
|
||||
|
||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||
if not self.temp_end:
|
||||
self.temp_end = self.current_sample
|
||||
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||
return None
|
||||
else:
|
||||
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
||||
self.temp_end = 0
|
||||
self.triggered = False
|
||||
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class FixedVADIterator(VADIterator):
|
||||
"""
|
||||
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
|
||||
"""
|
||||
|
||||
def reset_states(self):
|
||||
super().reset_states()
|
||||
self.buffer = np.array([], dtype=np.float32)
|
||||
|
||||
def __call__(self, x, return_seconds=False):
|
||||
self.buffer = np.append(self.buffer, x)
|
||||
ret = None
|
||||
while len(self.buffer) >= 512:
|
||||
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
|
||||
self.buffer = self.buffer[512:]
|
||||
if ret is None:
|
||||
ret = r
|
||||
elif r is not None:
|
||||
if "end" in r:
|
||||
ret["end"] = r["end"]
|
||||
if "start" in r:
|
||||
ret["start"] = r["start"]
|
||||
if "end" in ret:
|
||||
del ret["end"]
|
||||
return ret if ret != {} else None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# vad = FixedVADIterator(load_jit_vad())
|
||||
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
|
||||
|
||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 512 samples: {result}")
|
||||
|
||||
# test with 511 samples
|
||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 511 samples: {result}")
|
||||
0
whisperlivekit/silero_vad_models/__init__.py
Normal file
BIN
whisperlivekit/silero_vad_models/silero_vad.jit
Normal file
BIN
whisperlivekit/silero_vad_models/silero_vad.onnx
Normal file
BIN
whisperlivekit/silero_vad_models/silero_vad_16k_op15.onnx
Normal file
BIN
whisperlivekit/silero_vad_models/silero_vad_half.onnx
Normal file
@@ -0,0 +1,6 @@
|
||||
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
|
||||
|
||||
__all__ = [
|
||||
"SimulStreamingASR",
|
||||
"SimulStreamingOnlineProcessor",
|
||||
]
|
||||
|
||||
552
whisperlivekit/simul_whisper/align_att_base.py
Normal file
@@ -0,0 +1,552 @@
|
||||
"""Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||
|
||||
from .config import AlignAttConfig
|
||||
|
||||
DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlignAttBase(ABC):
|
||||
"""
|
||||
Abstract base class for AlignAtt streaming decoders.
|
||||
|
||||
Provides shared logic for both PyTorch and MLX implementations:
|
||||
- Properties (speaker, global_time_offset)
|
||||
- Pure-Python methods (warmup, trim_context, refresh_segment, etc.)
|
||||
- Template infer() with abstract hooks for tensor-specific operations
|
||||
- Post-decode logic (token splitting, timestamped word building)
|
||||
|
||||
Subclasses must implement ~20 abstract methods for tensor-specific ops.
|
||||
"""
|
||||
|
||||
# === Properties ===
|
||||
|
||||
@property
|
||||
def speaker(self):
|
||||
return self.state.speaker
|
||||
|
||||
@speaker.setter
|
||||
def speaker(self, value):
|
||||
self.state.speaker = value
|
||||
|
||||
@property
|
||||
def global_time_offset(self):
|
||||
return self.state.global_time_offset
|
||||
|
||||
@global_time_offset.setter
|
||||
def global_time_offset(self, value):
|
||||
self.state.global_time_offset = value
|
||||
|
||||
# === Constructor helpers ===
|
||||
|
||||
def _base_init(self, cfg: AlignAttConfig, model):
|
||||
"""Common initialization — call from subclass __init__."""
|
||||
self.model = model
|
||||
self.cfg = cfg
|
||||
self.decode_options = DecodingOptions(
|
||||
language=cfg.language,
|
||||
without_timestamps=True,
|
||||
task=cfg.task,
|
||||
)
|
||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||
self.max_text_len = model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(model.decoder.blocks)
|
||||
if cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
else:
|
||||
self.max_context_tokens = cfg.max_context_tokens
|
||||
|
||||
def _init_state_common(self, cfg: AlignAttConfig):
|
||||
"""Common state initialization — call from subclass _init_state."""
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
self.state.tokenizer = self.tokenizer
|
||||
self.state.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
self.state.global_time_offset = 0.0
|
||||
self.state.last_attend_frame = -cfg.rewind_threshold
|
||||
self.state.speaker = -1
|
||||
|
||||
# === Shared concrete methods ===
|
||||
|
||||
def warmup(self, audio):
|
||||
try:
|
||||
self.insert_audio(audio)
|
||||
self.infer(is_last=True)
|
||||
self.refresh_segment(complete=True)
|
||||
logger.info("Model warmed up successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"Model warmup failed: {e}")
|
||||
|
||||
def create_tokenizer(self, language=None):
|
||||
self.tokenizer = tokenizer.get_tokenizer(
|
||||
multilingual=self.tokenizer_is_multilingual,
|
||||
language=language,
|
||||
num_languages=self.model.num_languages,
|
||||
task=self.decode_options.task,
|
||||
)
|
||||
self.state.tokenizer = self.tokenizer
|
||||
|
||||
def trim_context(self):
|
||||
logger.info("Trimming context")
|
||||
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
|
||||
logger.info(f"Context text: {self.state.context.as_text()}")
|
||||
l = sum(t.shape[1] for t in self.state.tokens) + c
|
||||
after = 0 if self.cfg.static_init_prompt is None else len(self.cfg.static_init_prompt)
|
||||
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
||||
t = self.state.context.trim_words(after=after)
|
||||
l -= t
|
||||
c -= t
|
||||
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
if t == 0:
|
||||
break
|
||||
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
|
||||
|
||||
def refresh_segment(self, complete=False):
|
||||
logger.debug("Refreshing segment:")
|
||||
self.init_tokens()
|
||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.state.cumulative_time_offset = 0.0
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.state.context}")
|
||||
if not complete and len(self.state.segments) > 2:
|
||||
self.state.segments = self.state.segments[-2:]
|
||||
else:
|
||||
logger.debug("removing all segments.")
|
||||
self.state.segments = []
|
||||
self.state.log_segments += 1
|
||||
self.state.pending_incomplete_tokens = []
|
||||
self.state.pending_retries = 0
|
||||
|
||||
def segments_len(self):
|
||||
return sum(s.shape[0] for s in self.state.segments) / 16000
|
||||
|
||||
def _apply_minseglen(self):
|
||||
segments_len = self.segments_len()
|
||||
if segments_len < self.cfg.audio_min_len:
|
||||
logger.debug("waiting for next segment")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _clean_cache(self):
|
||||
self.state.clean_cache()
|
||||
|
||||
def debug_print_tokens(self, tokens):
|
||||
for i in range(min(self.cfg.beam_size, tokens.shape[0])):
|
||||
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
|
||||
|
||||
# === Language detection ===
|
||||
|
||||
def _detect_language_if_needed(self, encoder_feature):
|
||||
if (
|
||||
self.cfg.language == "auto"
|
||||
and self.state.detected_language is None
|
||||
and self.state.first_timestamp
|
||||
):
|
||||
seconds_since_start = self.segments_len() - self.state.first_timestamp
|
||||
if seconds_since_start >= 2.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
self.create_tokenizer(top_lan)
|
||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.state.cumulative_time_offset = 0.0
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
self.state.detected_language = top_lan
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}")
|
||||
|
||||
# === Template infer() ===
|
||||
|
||||
def infer(self, is_last=False):
|
||||
"""Main inference — template method calling abstract hooks for tensor ops."""
|
||||
new_segment = True
|
||||
|
||||
if len(self.state.segments) == 0:
|
||||
logger.debug("No segments, nothing to do")
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
return []
|
||||
|
||||
input_segments = self._concat_segments()
|
||||
encoder_feature, content_mel_len = self._encode(input_segments)
|
||||
self._evaluate(encoder_feature)
|
||||
|
||||
self._detect_language_if_needed(encoder_feature)
|
||||
self.trim_context()
|
||||
current_tokens = self._current_tokens()
|
||||
|
||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||
|
||||
sum_logprobs = self._init_sum_logprobs()
|
||||
completed = False
|
||||
token_len_before = current_tokens.shape[1]
|
||||
l_absolute_timestamps = []
|
||||
accumulated_cross_attns = []
|
||||
|
||||
audio_duration_s = self.segments_len()
|
||||
max_tokens = max(50, int(audio_duration_s * 15 * 1.5))
|
||||
tokens_produced = 0
|
||||
most_attended_frame = None
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len:
|
||||
tokens_produced += 1
|
||||
if tokens_produced > max_tokens:
|
||||
logger.warning(
|
||||
f"[Loop Detection] Too many tokens ({tokens_produced}) "
|
||||
f"for {audio_duration_s:.2f}s audio. Breaking."
|
||||
)
|
||||
current_tokens = current_tokens[:, :token_len_before]
|
||||
break
|
||||
|
||||
tokens_for_logits = current_tokens if new_segment else current_tokens[:, -1:]
|
||||
logits, cross_attns = self._get_logits_and_cross_attn(
|
||||
tokens_for_logits, encoder_feature
|
||||
)
|
||||
self._evaluate(logits)
|
||||
|
||||
accumulated_cross_attns.append(cross_attns)
|
||||
if len(accumulated_cross_attns) > 16:
|
||||
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||
|
||||
if new_segment and self._check_no_speech(logits):
|
||||
break
|
||||
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if new_segment:
|
||||
logits = self._suppress_blank_tokens(logits)
|
||||
new_segment = False
|
||||
|
||||
logits = self._apply_token_suppression(logits)
|
||||
logits = self._apply_dry_penalty(logits, current_tokens)
|
||||
current_tokens, completed = self._update_tokens(
|
||||
current_tokens, logits, sum_logprobs
|
||||
)
|
||||
self._evaluate(current_tokens)
|
||||
|
||||
logger.debug(f"Decoding completed: {completed}")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
|
||||
attn = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
|
||||
frames_list, most_attended_frame = self._get_attended_frames(attn)
|
||||
|
||||
absolute_timestamps = [
|
||||
(frame * 0.02 + self.state.cumulative_time_offset)
|
||||
for frame in frames_list
|
||||
]
|
||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
|
||||
|
||||
if completed:
|
||||
current_tokens = current_tokens[:, :-1]
|
||||
break
|
||||
|
||||
# Rewind check
|
||||
if (
|
||||
not is_last
|
||||
and self.state.last_attend_frame - most_attended_frame
|
||||
> self.cfg.rewind_threshold
|
||||
):
|
||||
if current_tokens.shape[1] > 1 and self._is_special_token(current_tokens):
|
||||
logger.debug("omit rewinding from special tokens")
|
||||
self.state.last_attend_frame = most_attended_frame
|
||||
else:
|
||||
logger.debug(
|
||||
f"[rewind detected] current: {most_attended_frame}, "
|
||||
f"last: {self.state.last_attend_frame}"
|
||||
)
|
||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||
current_tokens = self._rewind_tokens()
|
||||
break
|
||||
else:
|
||||
self.state.last_attend_frame = most_attended_frame
|
||||
|
||||
if content_mel_len - most_attended_frame <= (
|
||||
4 if is_last else self.cfg.frame_threshold
|
||||
):
|
||||
logger.debug(
|
||||
f"attention reaches the end: {most_attended_frame}/{content_mel_len}"
|
||||
)
|
||||
current_tokens = current_tokens[:, :-1]
|
||||
break
|
||||
|
||||
# Post-decode: split tokens and build timestamped words
|
||||
tokens_to_split = self._tokens_to_list(current_tokens, token_len_before)
|
||||
if self.state.pending_incomplete_tokens:
|
||||
logger.debug(
|
||||
f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} "
|
||||
f"pending tokens: {self.state.pending_incomplete_tokens}"
|
||||
)
|
||||
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
|
||||
|
||||
new_hypothesis, split_words, split_tokens = self._split_tokens(
|
||||
tokens_to_split, fire_detected, is_last
|
||||
)
|
||||
|
||||
new_tokens_tensor = self._make_new_tokens_tensor(new_hypothesis)
|
||||
self.state.tokens.append(new_tokens_tensor)
|
||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||
|
||||
self._clean_cache()
|
||||
|
||||
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
|
||||
self.state.first_timestamp = l_absolute_timestamps[0]
|
||||
|
||||
timestamped_words = self._build_timestamped_words(
|
||||
split_words, split_tokens, l_absolute_timestamps
|
||||
)
|
||||
self._handle_pending_tokens(split_words, split_tokens)
|
||||
|
||||
return timestamped_words
|
||||
|
||||
# === Post-decode shared helpers ===
|
||||
|
||||
def _split_tokens(self, tokens_list, fire_detected, is_last):
|
||||
"""Split token list into words. Returns (hypothesis, split_words, split_tokens)."""
|
||||
if fire_detected or is_last:
|
||||
new_hypothesis = tokens_list
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||
else:
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_list)
|
||||
if len(split_words) > 1:
|
||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||
else:
|
||||
new_hypothesis = []
|
||||
return new_hypothesis, split_words, split_tokens
|
||||
|
||||
def _build_timestamped_words(self, split_words, split_tokens, l_absolute_timestamps):
|
||||
"""Build list of timestamped ASRToken from split words."""
|
||||
timestamped_words = []
|
||||
timestamp_idx = 0
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
if replacement_char in word:
|
||||
cleaned = word.replace(replacement_char, "")
|
||||
if not cleaned.strip():
|
||||
logger.debug(f"[UTF-8 Filter] Skipping: {repr(word)}")
|
||||
timestamp_idx += len(word_tokens)
|
||||
continue
|
||||
logger.debug(f"[UTF-8 Filter] Cleaned {repr(word)} -> {repr(cleaned)}")
|
||||
word = cleaned
|
||||
|
||||
try:
|
||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||
except IndexError:
|
||||
logger.warning(
|
||||
f"Timestamp index {timestamp_idx} out of range, using last timestamp"
|
||||
)
|
||||
current_timestamp = (
|
||||
l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
|
||||
)
|
||||
timestamp_idx += len(word_tokens)
|
||||
|
||||
timestamp_entry = ASRToken(
|
||||
start=round(current_timestamp, 2),
|
||||
end=round(current_timestamp + 0.1, 2),
|
||||
text=word,
|
||||
speaker=self.state.speaker,
|
||||
detected_language=self.state.detected_language,
|
||||
).with_offset(self.state.global_time_offset)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
return timestamped_words
|
||||
|
||||
def _handle_pending_tokens(self, split_words, split_tokens):
|
||||
"""Handle incomplete UTF-8 tokens for next chunk."""
|
||||
MAX_PENDING_TOKENS = 10
|
||||
MAX_PENDING_RETRIES = 2
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
if split_words and replacement_char in split_words[-1]:
|
||||
self.state.pending_retries += 1
|
||||
if self.state.pending_retries > MAX_PENDING_RETRIES:
|
||||
logger.warning(
|
||||
f"[UTF-8 Fix] Dropping {len(split_tokens[-1])} incomplete tokens "
|
||||
f"after {MAX_PENDING_RETRIES} retries (won't resolve)"
|
||||
)
|
||||
self.state.pending_incomplete_tokens = []
|
||||
self.state.pending_retries = 0
|
||||
elif len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
|
||||
self.state.pending_incomplete_tokens = split_tokens[-1]
|
||||
logger.debug(
|
||||
f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} "
|
||||
f"incomplete tokens for next chunk (retry {self.state.pending_retries})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens "
|
||||
f"(exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)"
|
||||
)
|
||||
self.state.pending_incomplete_tokens = []
|
||||
self.state.pending_retries = 0
|
||||
else:
|
||||
self.state.pending_incomplete_tokens = []
|
||||
self.state.pending_retries = 0
|
||||
|
||||
# === Repetition penalty ===
|
||||
|
||||
def _apply_dry_penalty(self, logits, current_tokens):
|
||||
"""DRY penalty v0: penalize tokens that would extend a verbatim repetition.
|
||||
See https://github.com/oobabooga/text-generation-webui/pull/5677
|
||||
|
||||
Scans the decoded sequence for positions where the current suffix already
|
||||
appeared --> for each such match, the token that followed it in the past is
|
||||
penalised exponentially with the match length
|
||||
"""
|
||||
eot = self.tokenizer.eot
|
||||
seq = current_tokens[0].tolist()
|
||||
if len(seq) < 5:
|
||||
return logits
|
||||
|
||||
last = seq[-1]
|
||||
if last >= eot:
|
||||
return logits
|
||||
|
||||
penalties = {}
|
||||
for i in range(len(seq) - 2, -1, -1):
|
||||
if seq[i] != last:
|
||||
continue
|
||||
next_tok = seq[i + 1]
|
||||
if next_tok >= eot:
|
||||
continue
|
||||
|
||||
length = 1
|
||||
while length < 50:
|
||||
j, k = i - length, len(seq) - 1 - length
|
||||
if j < 0 or k <= i:
|
||||
break
|
||||
if seq[j] != seq[k] or seq[j] >= eot:
|
||||
break
|
||||
length += 1
|
||||
|
||||
if next_tok not in penalties or length > penalties[next_tok]:
|
||||
penalties[next_tok] = length
|
||||
|
||||
if penalties:
|
||||
max_len = max(penalties.values())
|
||||
if max_len >= 4:
|
||||
logger.debug(f"[DRY] penalising {len(penalties)} tokens (longest match: {max_len})")
|
||||
for tok, length in penalties.items():
|
||||
if length >= 2:
|
||||
logits[:, tok] = logits[:, tok] - 1.0 * 2.0 ** (length - 2)
|
||||
|
||||
return logits
|
||||
|
||||
# === Abstract methods — subclass must implement ===
|
||||
|
||||
@abstractmethod
|
||||
def _init_state(self, cfg: AlignAttConfig):
|
||||
"""Initialize per-session decoder state."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def init_tokens(self):
|
||||
"""Initialize token sequence with framework-specific tensors."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def init_context(self):
|
||||
"""Initialize context buffer with framework-specific TokenBuffer."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def insert_audio(self, segment=None):
|
||||
"""Insert audio segment into buffer."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _current_tokens(self):
|
||||
"""Build current token tensor for decoding."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def fire_at_boundary(self, feature):
|
||||
"""Check if we should fire at word boundary."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def lang_id(self, encoder_features):
|
||||
"""Language detection from encoder features. Returns (tokens, probs)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _concat_segments(self):
|
||||
"""Concatenate audio segments into single array/tensor."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _encode(self, input_segments):
|
||||
"""Encode audio. Returns (encoder_feature, content_mel_len)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _init_sum_logprobs(self):
|
||||
"""Create zero sum_logprobs tensor for beam search."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
|
||||
"""Get logits and cross-attention from decoder. Returns (logits, cross_attns)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _check_no_speech(self, logits):
|
||||
"""Check no_speech probability at start of segment. Returns True to break."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _suppress_blank_tokens(self, logits):
|
||||
"""Suppress blank/EOT tokens at segment start. Returns modified logits."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _apply_token_suppression(self, logits):
|
||||
"""Apply general token suppression. Returns modified logits."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _update_tokens(self, current_tokens, logits, sum_logprobs):
|
||||
"""Update tokens via decoder. Returns (current_tokens, completed)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _process_cross_attention(self, accumulated_cross_attns, content_mel_len):
|
||||
"""Process cross-attention for alignment. Returns attention tensor."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_attended_frames(self, attn):
|
||||
"""Get most attended frames. Returns (frames_as_python_list, first_frame_int)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _is_special_token(self, current_tokens):
|
||||
"""Check if second-to-last token is a special token (>= DEC_PAD)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _rewind_tokens(self):
|
||||
"""Concatenate state tokens for rewind. Returns token tensor."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _tokens_to_list(self, current_tokens, start_col):
|
||||
"""Extract tokens as Python list from start_col onwards."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _make_new_tokens_tensor(self, hypothesis):
|
||||
"""Create tensor from hypothesis token list, repeated for beam search."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _evaluate(self, tensor):
|
||||
"""Evaluate lazy tensor (mx.eval for MLX, no-op for PyTorch)."""
|
||||
...
|
||||
368
whisperlivekit/simul_whisper/backend.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||
from whisperlivekit.warmup import load_file
|
||||
from whisperlivekit.whisper import load_model, tokenizer
|
||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
||||
from .mlx import MLXAlignAtt
|
||||
else:
|
||||
mlx_model_mapping = {}
|
||||
MLXAlignAtt = None
|
||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||
if HAS_FASTER_WHISPER:
|
||||
from faster_whisper import WhisperModel
|
||||
else:
|
||||
WhisperModel = None
|
||||
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
"""Online processor for SimulStreaming ASR."""
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(self, asr, logfile=sys.stderr):
|
||||
self.asr = asr
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.buffer = []
|
||||
self.model = self._create_alignatt()
|
||||
|
||||
if asr.tokenizer:
|
||||
self.model.tokenizer = asr.tokenizer
|
||||
self.model.state.tokenizer = asr.tokenizer
|
||||
|
||||
def _create_alignatt(self):
|
||||
"""Create the AlignAtt decoder instance based on ASR mode."""
|
||||
if self.asr.use_full_mlx and HAS_MLX_WHISPER:
|
||||
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
|
||||
else:
|
||||
return AlignAtt(
|
||||
cfg=self.asr.cfg,
|
||||
loaded_model=self.asr.shared_model,
|
||||
mlx_encoder=self.asr.mlx_encoder,
|
||||
fw_encoder=self.asr.fw_encoder,
|
||||
)
|
||||
|
||||
def start_silence(self):
|
||||
tokens, processed_upto = self.process_iter(is_last=True)
|
||||
return tokens, processed_upto
|
||||
|
||||
def end_silence(self, silence_duration, offset):
|
||||
"""Handle silence period."""
|
||||
self.end += silence_duration
|
||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||
if not long_silence:
|
||||
gap_len = int(16000 * silence_duration)
|
||||
if gap_len > 0:
|
||||
if self.asr.use_full_mlx:
|
||||
gap_silence = np.zeros(gap_len, dtype=np.float32)
|
||||
else:
|
||||
gap_silence = torch.zeros(gap_len)
|
||||
self.model.insert_audio(gap_silence)
|
||||
if long_silence:
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.global_time_offset = silence_duration + offset
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||
self.end = audio_stream_end_time
|
||||
if self.asr.use_full_mlx:
|
||||
self.model.insert_audio(audio)
|
||||
else:
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
"""Handle speaker change event."""
|
||||
self.process_iter(is_last=True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.speaker = change_speaker.speaker
|
||||
self.model.global_time_offset = change_speaker.start
|
||||
|
||||
def get_buffer(self):
|
||||
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||
return concat_buffer
|
||||
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Process accumulated audio chunks using SimulStreaming.
|
||||
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
try:
|
||||
timestamped_words = self.model.infer(is_last=is_last)
|
||||
|
||||
if not timestamped_words:
|
||||
return [], self.end
|
||||
|
||||
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
except Exception as e:
|
||||
logger.exception(f"SimulStreaming processing error: {e}")
|
||||
return [], self.end
|
||||
|
||||
def warmup(self, audio, init_prompt=""):
|
||||
"""Warmup the SimulStreaming model."""
|
||||
try:
|
||||
if self.asr.use_full_mlx:
|
||||
# MLX mode: ensure numpy array
|
||||
if hasattr(audio, 'numpy'):
|
||||
audio = audio.numpy()
|
||||
self.model.insert_audio(audio)
|
||||
self.model.infer(True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
logger.info("SimulStreaming model warmed up successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||
|
||||
def __del__(self):
|
||||
gc.collect()
|
||||
if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class SimulStreamingASR:
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = ""
|
||||
|
||||
def __init__(self, logfile=sys.stderr, **kwargs):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
if self.decoder_type is None:
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
self._resolved_model_path = None
|
||||
self.encoder_backend = "whisper"
|
||||
self.use_full_mlx = getattr(self, "use_full_mlx", False)
|
||||
preferred_backend = getattr(self, "backend", "auto")
|
||||
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
||||
|
||||
if self.model_path:
|
||||
resolved_model_path = resolve_model_path(self.model_path)
|
||||
self._resolved_model_path = resolved_model_path
|
||||
self.model_path = str(resolved_model_path)
|
||||
|
||||
model_info = detect_model_format(resolved_model_path)
|
||||
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
||||
compatible_faster_whisper = model_info.compatible_faster_whisper
|
||||
|
||||
if not self.use_full_mlx and not model_info.has_pytorch:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||
)
|
||||
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
|
||||
elif self.model_size is not None:
|
||||
self.model_name = self.model_size
|
||||
else:
|
||||
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||
|
||||
is_multilingual = not self.model_name.endswith(".en")
|
||||
|
||||
self.encoder_backend = self._resolve_encoder_backend(
|
||||
preferred_backend,
|
||||
compatible_whisper_mlx,
|
||||
compatible_faster_whisper,
|
||||
)
|
||||
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||
if self.encoder_backend == "whisper":
|
||||
self.disable_fast_encoder = True
|
||||
|
||||
# MLX full decoder disabled by default — MLXAlignAtt has known issues
|
||||
# with token generation after punctuation. Users can opt-in with
|
||||
# --use-full-mlx if they want to test it.
|
||||
# if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
|
||||
# if not hasattr(self, '_full_mlx_disabled'):
|
||||
# self.use_full_mlx = True
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
tokenizer_is_multilingual= is_multilingual,
|
||||
segment_length=self.min_chunk_size,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.lan,
|
||||
audio_max_len=self.audio_max_len,
|
||||
audio_min_len=self.audio_min_len,
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task="translate" if self.direct_english_translation else "transcribe",
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
static_init_prompt=self.static_init_prompt,
|
||||
)
|
||||
|
||||
# Set up tokenizer for translation if needed
|
||||
if self.direct_english_translation:
|
||||
self.tokenizer = self.set_translate_task()
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
|
||||
self.shared_model = None
|
||||
|
||||
if self.use_full_mlx and HAS_MLX_WHISPER:
|
||||
logger.info('MLX Whisper backend used.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model_path = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model_path:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
|
||||
self._warmup_mlx_model()
|
||||
elif self.encoder_backend == "mlx-whisper":
|
||||
# hybrid mode: mlx encoder + pytorch decoder
|
||||
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model_path = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model_path:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
||||
self.shared_model = self.load_model()
|
||||
elif self.encoder_backend == "faster-whisper":
|
||||
print('SimulStreaming will use Faster Whisper for the encoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
fw_model = str(self._resolved_model_path)
|
||||
else:
|
||||
fw_model = self.model_name
|
||||
self.fw_encoder = WhisperModel(
|
||||
fw_model,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
self.shared_model = self.load_model()
|
||||
else:
|
||||
self.shared_model = self.load_model()
|
||||
|
||||
def _warmup_mlx_model(self):
|
||||
"""Warmup the full MLX model."""
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
temp_model = MLXAlignAtt(
|
||||
cfg=self.cfg,
|
||||
mlx_model=self.mlx_model,
|
||||
)
|
||||
temp_model.warmup(warmup_audio)
|
||||
logger.info("Full MLX model warmed up successfully")
|
||||
|
||||
|
||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||
choice = preferred_backend or "auto"
|
||||
if self.disable_fast_encoder:
|
||||
return "whisper"
|
||||
if choice == "whisper":
|
||||
return "whisper"
|
||||
if choice == "mlx-whisper":
|
||||
if not self._can_use_mlx(compatible_whisper_mlx):
|
||||
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
|
||||
return "mlx-whisper"
|
||||
if choice == "faster-whisper":
|
||||
if not self._can_use_faster(compatible_faster_whisper):
|
||||
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
|
||||
return "faster-whisper"
|
||||
if choice == "openai-api":
|
||||
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
|
||||
# auto mode
|
||||
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
|
||||
return "mlx-whisper"
|
||||
if self._can_use_faster(compatible_faster_whisper):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
def _has_custom_model_path(self):
|
||||
return self._resolved_model_path is not None
|
||||
|
||||
def _can_use_mlx(self, compatible_whisper_mlx):
|
||||
if not HAS_MLX_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_whisper_mlx
|
||||
return self.model_name in mlx_model_mapping
|
||||
|
||||
def _can_use_faster(self, compatible_faster_whisper):
|
||||
if not HAS_FASTER_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_faster_whisper
|
||||
return True
|
||||
|
||||
def load_model(self):
|
||||
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
|
||||
lora_path = getattr(self, 'lora_path', None)
|
||||
whisper_model = load_model(
|
||||
name=model_ref,
|
||||
download_root=getattr(self, 'model_cache_dir', None),
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads,
|
||||
lora_path=lora_path,
|
||||
)
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||
if self.fast_encoder:
|
||||
temp_model = AlignAtt(
|
||||
cfg=self.cfg,
|
||||
loaded_model=whisper_model,
|
||||
mlx_encoder=self.mlx_encoder,
|
||||
fw_encoder=self.fw_encoder,
|
||||
)
|
||||
temp_model.warmup(warmup_audio)
|
||||
else:
|
||||
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
||||
return whisper_model
|
||||
|
||||
def set_translate_task(self):
|
||||
"""Set up translation task."""
|
||||
if self.cfg.language == 'auto':
|
||||
raise ValueError('Translation cannot be done with language = auto')
|
||||
return tokenizer.get_tokenizer(
|
||||
multilingual=True,
|
||||
language=self.cfg.language,
|
||||
num_languages=99,
|
||||
task="translate"
|
||||
)
|
||||
|
||||
def transcribe(self, audio):
|
||||
"""
|
||||
Warmup is done directly in load_model
|
||||
"""
|
||||
pass
|
||||
@@ -1,17 +1,32 @@
|
||||
from .whisper.decoding import PyTorchInference
|
||||
from torch import Tensor
|
||||
|
||||
from whisperlivekit.whisper.decoding import PyTorchInference
|
||||
|
||||
|
||||
# extention of PyTorchInference for beam search
|
||||
class BeamPyTorchInference(PyTorchInference):
|
||||
"""Extension of PyTorchInference for beam search with cross-attention support."""
|
||||
|
||||
def _kv_modules(self):
|
||||
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
|
||||
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
|
||||
return key_modules + value_modules
|
||||
def _kv_cache_ids(self):
|
||||
"""Get cache_id strings for self-attention key/value modules."""
|
||||
key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks]
|
||||
value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks]
|
||||
return key_ids + value_ids
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
if source_indices != list(range(len(source_indices))):
|
||||
for module_cache_id in self._kv_modules():
|
||||
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
|
||||
from torch import Tensor
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
for cache_id in self._kv_cache_ids():
|
||||
if cache_id in self.kv_cache:
|
||||
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
|
||||
|
||||
def logits(
|
||||
self,
|
||||
tokens: Tensor,
|
||||
audio_features: Tensor,
|
||||
return_cross_attn: bool = False,
|
||||
):
|
||||
"""Get logits, optionally returning cross-attention weights."""
|
||||
return self.model.decoder(
|
||||
tokens, audio_features,
|
||||
kv_cache=self.kv_cache,
|
||||
return_cross_attn=return_cross_attn,
|
||||
)
|
||||
@@ -1,29 +1,24 @@
|
||||
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
@dataclass
|
||||
class SimulWhisperConfig:
|
||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
||||
model_path: str
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 1.0
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
@dataclass
|
||||
class AlignAttConfig(SimulWhisperConfig):
|
||||
'''Options specific to the AlignAtt policy.'''
|
||||
class AlignAttConfig():
|
||||
eval_data_path: str = "tmp"
|
||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||
frame_threshold: int = 4
|
||||
rewind_threshold: int = 200
|
||||
audio_max_len: float = 30.0
|
||||
audio_max_len: float = 20.0
|
||||
cif_ckpt_path: str = ""
|
||||
never_fire: bool = False
|
||||
never_fire: bool = False
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
tokenizer_is_multilingual: bool = False
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
97
whisperlivekit/simul_whisper/decoder_state.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecoderState:
|
||||
|
||||
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
|
||||
|
||||
tokenizer: Any = None
|
||||
detected_language: Optional[str] = None
|
||||
reset_tokenizer_to_auto_next_call: bool = False
|
||||
|
||||
tokens: List[torch.Tensor] = field(default_factory=list)
|
||||
initial_tokens: Optional[torch.Tensor] = None
|
||||
initial_token_length: int = 0
|
||||
sot_index: int = 0
|
||||
|
||||
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
|
||||
num_align_heads: int = 0
|
||||
|
||||
segments: List[torch.Tensor] = field(default_factory=list)
|
||||
|
||||
context: Any = None
|
||||
|
||||
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||
pending_retries: int = 0
|
||||
|
||||
global_time_offset: float = 0.0
|
||||
cumulative_time_offset: float = 0.0
|
||||
first_timestamp: Optional[float] = None
|
||||
last_attend_frame: int = 0
|
||||
|
||||
speaker: int = -1
|
||||
log_segments: int = 0
|
||||
|
||||
CIFLinear: Optional[torch.nn.Module] = None
|
||||
always_fire: bool = False
|
||||
never_fire: bool = False
|
||||
|
||||
suppress_tokens_fn: Any = None
|
||||
|
||||
token_decoder: Any = None
|
||||
decoder_type: str = "greedy"
|
||||
|
||||
inference: Any = None
|
||||
|
||||
def clean_cache(self):
|
||||
"""Clean the kv_cache after each inference step."""
|
||||
# Explicitly delete tensor references to free GPU memory
|
||||
if self.kv_cache:
|
||||
for key in list(self.kv_cache.keys()):
|
||||
tensor = self.kv_cache.pop(key, None)
|
||||
if tensor is not None:
|
||||
del tensor
|
||||
|
||||
# Clear the dict
|
||||
self.kv_cache.clear()
|
||||
|
||||
# Force GPU cache cleanup (only if CUDA is available)
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.decoder_type == "beam" and self.inference is not None:
|
||||
# Create NEW dict instead of sharing reference
|
||||
self.inference.kv_cache = {}
|
||||
if self.token_decoder is not None:
|
||||
self.token_decoder.reset()
|
||||
|
||||
def reset(self, rewind_threshold: int = 200):
|
||||
"""
|
||||
Reset transient state for a new segment.
|
||||
|
||||
Args:
|
||||
rewind_threshold: Value for resetting last_attend_frame
|
||||
"""
|
||||
self.last_attend_frame = -rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.pending_incomplete_tokens = []
|
||||
self.pending_retries = 0
|
||||
self.log_segments += 1
|
||||
|
||||
def full_reset(self, rewind_threshold: int = 200):
|
||||
"""
|
||||
Full reset including audio segments and tokens.
|
||||
|
||||
Args:
|
||||
rewind_threshold: Value for resetting last_attend_frame
|
||||
"""
|
||||
self.reset(rewind_threshold)
|
||||
self.segments = []
|
||||
self.tokens = []
|
||||
self.kv_cache = {}
|
||||
self.first_timestamp = None
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
|
||||
|
||||
SimulStreaming is dual-licensed:
|
||||
|
||||
🔹 Non-Commercial Use
|
||||
|
||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
|
||||
obtain the code through the GitHub repository. This license is **free of charge**
|
||||
and comes with **no obligations** for non-commercial users.
|
||||
|
||||
🔸 Commercial Use
|
||||
|
||||
Understanding who uses SimulStreaming commercially helps us improve and
|
||||
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
|
||||
|
||||
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
|
||||
are considering to provide commercial licenses either for free or for symbolic
|
||||
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
|
||||
|
||||
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
|
||||
available.
|
||||
|
||||
✉️ Contact
|
||||
|
||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
||||
@@ -1,40 +0,0 @@
|
||||
class Tokens:
|
||||
def __init__(self, tokens):
|
||||
self.tokens = tokens
|
||||
|
||||
# def clone(self):
|
||||
# return Tokens(self.tokens.clone())
|
||||
|
||||
def __str__(self):
|
||||
return str(self.tokens.tolist())
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
class BeamTokens(Tokens):
|
||||
def __init__(self, tokens, beam_size):
|
||||
self.tokens = tokens
|
||||
self.beam_size = beam_size
|
||||
|
||||
def clone(self):
|
||||
return BeamTokens(self.tokens.clone())
|
||||
|
||||
def __str__(self):
|
||||
return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
class Logits(Tokens):
|
||||
def __init__(self, logits):
|
||||
super().__init__(logits)
|
||||
|
||||
# def clone(self):
|
||||
# return Logits(self.tokens.clone(), self.beam_size)
|
||||
|
||||
def __str__(self):
|
||||
# return "abc"
|
||||
return f"Logits({self.tokens.shape})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
11
whisperlivekit/simul_whisper/mlx/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .decoder_state import MLXDecoderState
|
||||
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||
from .simul_whisper import MLXAlignAtt
|
||||
|
||||
__all__ = [
|
||||
"MLXAlignAtt",
|
||||
"MLXBeamSearchDecoder",
|
||||
"MLXDecoderState",
|
||||
"MLXGreedyDecoder",
|
||||
"MLXInference",
|
||||
]
|
||||
78
whisperlivekit/simul_whisper/mlx/decoder_state.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLXDecoderState:
|
||||
"""
|
||||
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
|
||||
where each element is a tuple of mx.arrays.
|
||||
"""
|
||||
|
||||
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
|
||||
|
||||
tokenizer: Any = None
|
||||
detected_language: Optional[str] = None
|
||||
reset_tokenizer_to_auto_next_call: bool = False
|
||||
|
||||
tokens: List[mx.array] = field(default_factory=list)
|
||||
initial_tokens: Optional[mx.array] = None
|
||||
initial_token_length: int = 0
|
||||
sot_index: int = 0
|
||||
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
|
||||
num_align_heads: int = 0
|
||||
segments: List[np.ndarray] = field(default_factory=list)
|
||||
|
||||
context: Any = None
|
||||
|
||||
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||
pending_retries: int = 0
|
||||
|
||||
global_time_offset: float = 0.0
|
||||
cumulative_time_offset: float = 0.0
|
||||
first_timestamp: Optional[float] = None
|
||||
last_attend_frame: int = 0
|
||||
|
||||
speaker: int = -1
|
||||
log_segments: int = 0
|
||||
cif_weights: Optional[mx.array] = None
|
||||
always_fire: bool = False
|
||||
never_fire: bool = False
|
||||
|
||||
suppress_tokens: Optional[Tuple[int, ...]] = None
|
||||
|
||||
token_decoder: Any = None
|
||||
decoder_type: str = "greedy"
|
||||
|
||||
inference: Any = None
|
||||
|
||||
def clean_cache(self):
|
||||
self.kv_cache = None
|
||||
if self.decoder_type == "beam" and self.inference is not None:
|
||||
self.inference.kv_cache = None
|
||||
if self.token_decoder is not None:
|
||||
self.token_decoder.reset()
|
||||
|
||||
def reset(self, rewind_threshold: int = 200):
|
||||
self.last_attend_frame = -rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.pending_incomplete_tokens = []
|
||||
self.pending_retries = 0
|
||||
self.log_segments += 1
|
||||
|
||||
def full_reset(self, rewind_threshold: int = 200):
|
||||
"""
|
||||
Full reset including audio segments and tokens.
|
||||
|
||||
Args:
|
||||
rewind_threshold: Value for resetting last_attend_frame
|
||||
"""
|
||||
self.reset(rewind_threshold)
|
||||
self.segments = []
|
||||
self.tokens = []
|
||||
self.kv_cache = None
|
||||
self.first_timestamp = None
|
||||
|
||||
219
whisperlivekit/simul_whisper/mlx/decoders.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
MLX-native token decoders for streaming ASR.
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MLXGreedyDecoder:
|
||||
"""Greedy decoder using MLX operations."""
|
||||
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(
|
||||
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||
) -> Tuple[mx.array, bool]:
|
||||
"""
|
||||
Update tokens with next predicted token.
|
||||
|
||||
Args:
|
||||
tokens: Current token sequence, shape (batch, seq_len)
|
||||
logits: Logits for next token, shape (batch, vocab_size)
|
||||
sum_logprobs: Cumulative log probabilities, shape (batch,)
|
||||
|
||||
Returns:
|
||||
Updated tokens and completion flag
|
||||
"""
|
||||
if self.temperature == 0:
|
||||
next_tokens = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
probs = mx.softmax(logits / self.temperature, axis=-1)
|
||||
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
|
||||
|
||||
logprobs = mx.softmax(logits, axis=-1)
|
||||
logprobs = mx.log(logprobs + 1e-10)
|
||||
batch_size = logprobs.shape[0]
|
||||
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
|
||||
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
|
||||
sum_logprobs = sum_logprobs + current_logprobs * mask
|
||||
eot_mask = (tokens[:, -1] == self.eot)
|
||||
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
|
||||
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
|
||||
completed = bool(mx.all(tokens[:, -1] == self.eot))
|
||||
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
||||
"""Finalize decoding by ensuring EOT at end."""
|
||||
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
|
||||
tokens = mx.concatenate([tokens, eot_column], axis=1)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
|
||||
|
||||
class MLXBeamSearchDecoder:
|
||||
"""Beam search decoder using MLX operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
beam_size: int,
|
||||
eot: int,
|
||||
inference: Any,
|
||||
patience: Optional[float] = None,
|
||||
):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
self.patience = patience or 1.0
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences: Optional[List[Dict]] = None
|
||||
|
||||
assert (
|
||||
self.max_candidates > 0
|
||||
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
"""Reset finished sequences for new segment."""
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(
|
||||
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||
) -> Tuple[mx.array, bool]:
|
||||
"""
|
||||
Update tokens using beam search.
|
||||
|
||||
Args:
|
||||
tokens: Current token sequences, shape (batch * beam_size, seq_len)
|
||||
logits: Logits for next token, shape (batch * beam_size, vocab_size)
|
||||
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
|
||||
|
||||
Returns:
|
||||
Updated tokens and completion flag
|
||||
"""
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
n_audio = tokens.shape[0] // self.beam_size
|
||||
if self.finished_sequences is None:
|
||||
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||
logprobs = mx.softmax(logits, axis=-1)
|
||||
logprobs = mx.log(logprobs + 1e-10)
|
||||
logprobs_np = np.array(logprobs)
|
||||
tokens_np = np.array(tokens)
|
||||
sum_logprobs_np = np.array(sum_logprobs)
|
||||
|
||||
next_tokens, source_indices, finished_sequences = [], [], []
|
||||
new_sum_logprobs = []
|
||||
|
||||
for i in range(n_audio):
|
||||
scores, sources, finished = {}, {}, {}
|
||||
for j in range(self.beam_size):
|
||||
idx = i * self.beam_size + j
|
||||
prefix = tokens_np[idx].tolist()
|
||||
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
|
||||
|
||||
for token_idx in top_k_indices:
|
||||
logprob = logprobs_np[idx, token_idx]
|
||||
new_logprob = sum_logprobs_np[idx] + logprob
|
||||
sequence = tuple(prefix + [int(token_idx)])
|
||||
scores[sequence] = new_logprob
|
||||
sources[sequence] = idx
|
||||
saved = 0
|
||||
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||
if sequence[-1] == self.eot:
|
||||
finished[sequence] = scores[sequence]
|
||||
else:
|
||||
new_sum_logprobs.append(scores[sequence])
|
||||
next_tokens.append(sequence)
|
||||
source_indices.append(sources[sequence])
|
||||
|
||||
saved += 1
|
||||
if saved == self.beam_size:
|
||||
break
|
||||
|
||||
finished_sequences.append(finished)
|
||||
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
|
||||
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
|
||||
self.inference.rearrange_kv_cache(source_indices)
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(
|
||||
self.finished_sequences, finished_sequences
|
||||
):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break
|
||||
previously_finished[seq] = newly_finished[seq]
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates
|
||||
for sequences in self.finished_sequences
|
||||
)
|
||||
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
|
||||
"""Finalize beam search by selecting best sequences."""
|
||||
preceding_tokens_np = np.array(preceding_tokens)
|
||||
sum_logprobs_np = np.array(sum_logprobs)
|
||||
|
||||
n_audio = preceding_tokens_np.shape[0] // self.beam_size
|
||||
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
|
||||
sum_logprobs_list: List[float] = [0.0] * n_audio
|
||||
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if sequences:
|
||||
best_seq = max(sequences, key=sequences.get)
|
||||
tokens_list[i] = list(best_seq)
|
||||
sum_logprobs_list[i] = sequences[best_seq]
|
||||
else:
|
||||
idx = i * self.beam_size
|
||||
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
|
||||
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
|
||||
max_len = max(len(t) for t in tokens_list)
|
||||
for i, t in enumerate(tokens_list):
|
||||
tokens_list[i] = t + [self.eot] * (max_len - len(t))
|
||||
|
||||
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
|
||||
return tokens, sum_logprobs_list
|
||||
|
||||
|
||||
class MLXInference:
|
||||
"""MLX inference wrapper for beam search KV cache management."""
|
||||
|
||||
def __init__(self, model, initial_token_length: int):
|
||||
self.model = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = None
|
||||
|
||||
def rearrange_kv_cache(self, source_indices: List[int]):
|
||||
"""Rearrange KV cache based on beam search source indices."""
|
||||
if self.kv_cache is None:
|
||||
return
|
||||
|
||||
if source_indices == list(range(len(source_indices))):
|
||||
return
|
||||
|
||||
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
|
||||
|
||||
new_cache = []
|
||||
for layer_cache in self.kv_cache:
|
||||
(k, v), (cross_k, cross_v) = layer_cache
|
||||
new_k = k[source_indices_mx]
|
||||
new_v = v[source_indices_mx]
|
||||
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
|
||||
|
||||
self.kv_cache = new_cache
|
||||
|
||||
def logits(
|
||||
self,
|
||||
tokens: mx.array,
|
||||
audio_features: mx.array,
|
||||
) -> Tuple[mx.array, List]:
|
||||
"""Get logits from decoder with KV cache."""
|
||||
logits, self.kv_cache, cross_qk = self.model.decoder(
|
||||
tokens, audio_features, kv_cache=self.kv_cache
|
||||
)
|
||||
return logits, cross_qk
|
||||
|
||||
421
whisperlivekit/simul_whisper/mlx/simul_whisper.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""MLX whisper AlignAtt streaming decoder."""
|
||||
import logging
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
|
||||
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
|
||||
|
||||
from ..align_att_base import DEC_PAD, AlignAttBase
|
||||
from ..config import AlignAttConfig
|
||||
from .decoder_state import MLXDecoderState
|
||||
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MLXTokenBuffer:
|
||||
"""Token buffer for MLX-based decoding."""
|
||||
|
||||
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
|
||||
self.text = text
|
||||
self.prefix_token_ids = prefix_token_ids or []
|
||||
self.tokenizer = tokenizer
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_token_ids(self, tokenizer=None):
|
||||
if tokenizer is None:
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError("Tokenizer is not set.")
|
||||
return self.prefix_token_ids + tokenizer.encode(self.text)
|
||||
|
||||
def as_mlx_array(self) -> mx.array:
|
||||
tok_ids = self.as_token_ids()
|
||||
return mx.array([tok_ids], dtype=mx.int32)
|
||||
|
||||
def as_mlx_array_beam(self, beam: int) -> mx.array:
|
||||
t = self.as_mlx_array()
|
||||
return mx.repeat(t, beam, axis=0)
|
||||
|
||||
def as_text(self):
|
||||
return self.text
|
||||
|
||||
@staticmethod
|
||||
def empty(*a, **kw):
|
||||
return MLXTokenBuffer(*a, **kw)
|
||||
|
||||
@staticmethod
|
||||
def from_text(text, *a, **kw):
|
||||
return MLXTokenBuffer(*a, text=text, **kw)
|
||||
|
||||
def is_empty(self):
|
||||
return self.text is None or self.text == ""
|
||||
|
||||
def trim_words(self, num=1, after=0):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
ids = tokenizer.encode(self.text[after:])
|
||||
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
||||
if not words:
|
||||
return 0
|
||||
self.text = self.text[:after] + "".join(words[num:])
|
||||
return sum(len(wi) for wi in wids[:num])
|
||||
|
||||
def append_token_ids(self, token_ids):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
all_tokens = self.pending_token_ids + token_ids
|
||||
decoded = tokenizer.decode(all_tokens)
|
||||
replacement_char = "\ufffd"
|
||||
if replacement_char in decoded:
|
||||
if len(all_tokens) > 1:
|
||||
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||
if replacement_char not in decoded_partial:
|
||||
self.text += decoded_partial
|
||||
self.pending_token_ids = [all_tokens[-1]]
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.text += decoded
|
||||
self.pending_token_ids = []
|
||||
|
||||
|
||||
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
|
||||
"""Apply median filter along the last axis."""
|
||||
if filter_width <= 1:
|
||||
return x
|
||||
pad_width = filter_width // 2
|
||||
shape = x.shape
|
||||
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
|
||||
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
|
||||
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
|
||||
result = []
|
||||
for i in range(shape[-1]):
|
||||
window = x_padded[..., i:i + filter_width]
|
||||
sorted_window = mx.sort(window, axis=-1)
|
||||
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
|
||||
result.append(median_val)
|
||||
return mx.concatenate(result, axis=-1)
|
||||
|
||||
|
||||
class MLXAlignAtt(AlignAttBase):
|
||||
"""
|
||||
MLX-native Alignment-based Attention decoder for SimulStreaming.
|
||||
|
||||
Runs entirely on MLX, with no PyTorch dependencies for inference.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: AlignAttConfig,
|
||||
mlx_model: Any,
|
||||
) -> None:
|
||||
# Common init (sets self.model, self.cfg, decode_options, etc.)
|
||||
self._base_init(cfg, mlx_model)
|
||||
logger.info(f"MLX Model dimensions: {self.model.dims}")
|
||||
|
||||
# Per-session state
|
||||
self.state = MLXDecoderState()
|
||||
self._init_state(cfg)
|
||||
|
||||
def _init_state(self, cfg: AlignAttConfig):
|
||||
self._init_state_common(cfg)
|
||||
|
||||
# CIF: MLX doesn't support CIF checkpoint loading
|
||||
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
|
||||
if cfg.never_fire:
|
||||
self.state.never_fire = True
|
||||
self.state.always_fire = False
|
||||
else:
|
||||
self.state.always_fire = True
|
||||
self.state.never_fire = False
|
||||
else:
|
||||
logger.warning(
|
||||
"CIF checkpoint provided but MLX CIF not implemented. "
|
||||
"Using always_fire=True"
|
||||
)
|
||||
self.state.always_fire = True
|
||||
self.state.never_fire = cfg.never_fire
|
||||
|
||||
self._build_alignment_source()
|
||||
|
||||
# Suppress tokens
|
||||
suppress_tokens = [
|
||||
self.tokenizer.transcribe, self.tokenizer.translate,
|
||||
self.tokenizer.sot, self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm, self.tokenizer.no_timestamps,
|
||||
] + list(self.tokenizer.all_language_tokens)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
|
||||
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
|
||||
# Decoder type
|
||||
self.state.decoder_type = cfg.decoder_type
|
||||
if cfg.decoder_type == "greedy":
|
||||
logger.info("Using MLX greedy decoder")
|
||||
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
|
||||
elif cfg.decoder_type == "beam":
|
||||
logger.info("Using MLX beam decoder")
|
||||
self.state.inference = MLXInference(
|
||||
self.model, self.state.initial_token_length,
|
||||
)
|
||||
self.state.token_decoder = MLXBeamSearchDecoder(
|
||||
inference=self.state.inference,
|
||||
eot=self.tokenizer.eot,
|
||||
beam_size=cfg.beam_size,
|
||||
)
|
||||
|
||||
def _build_alignment_source(self):
|
||||
"""Build alignment source mapping from model's alignment_heads."""
|
||||
self.state.align_source = {}
|
||||
self.state.num_align_heads = 0
|
||||
alignment_heads = self.model.alignment_heads
|
||||
if alignment_heads is None:
|
||||
logger.warning("No alignment heads found in model")
|
||||
return
|
||||
if hasattr(alignment_heads, 'tolist'):
|
||||
heads_list = alignment_heads.tolist()
|
||||
else:
|
||||
heads_list = np.array(alignment_heads).tolist()
|
||||
for layer_rank, head_id in heads_list:
|
||||
layer_rank = int(layer_rank)
|
||||
head_id = int(head_id)
|
||||
heads = self.state.align_source.get(layer_rank, [])
|
||||
heads.append((self.state.num_align_heads, head_id))
|
||||
self.state.align_source[layer_rank] = heads
|
||||
self.state.num_align_heads += 1
|
||||
|
||||
# === Abstract method implementations ===
|
||||
|
||||
def init_tokens(self):
|
||||
logger.debug(f"init tokens, {len(self.state.segments)}")
|
||||
self.state.initial_tokens = mx.array(
|
||||
[self.tokenizer.sot_sequence_including_notimestamps],
|
||||
dtype=mx.int32,
|
||||
)
|
||||
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
||||
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
||||
self.state.tokens = [self.state.initial_tokens]
|
||||
|
||||
def init_context(self):
|
||||
kw = {
|
||||
'tokenizer': self.tokenizer,
|
||||
'prefix_token_ids': [self.tokenizer.sot_prev],
|
||||
}
|
||||
self.state.context = MLXTokenBuffer.empty(**kw)
|
||||
if self.cfg.static_init_prompt is not None:
|
||||
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||
if self.cfg.init_prompt is not None:
|
||||
self.state.context.text += self.cfg.init_prompt
|
||||
|
||||
def insert_audio(self, segment=None):
|
||||
if segment is not None:
|
||||
if hasattr(segment, 'numpy'):
|
||||
segment = segment.numpy()
|
||||
self.state.segments.append(segment)
|
||||
removed_len = 0
|
||||
segments_len = self.segments_len()
|
||||
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||
removed_len = self.state.segments[0].shape[0] / 16000
|
||||
segments_len -= removed_len
|
||||
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
||||
self.state.cumulative_time_offset += removed_len
|
||||
self.state.segments = self.state.segments[1:]
|
||||
logger.debug(
|
||||
f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, "
|
||||
f"cumulative offset: {self.state.cumulative_time_offset:.2f}s"
|
||||
)
|
||||
if len(self.state.tokens) > 1:
|
||||
token_list = np.array(self.state.tokens[1][0, :]).tolist()
|
||||
self.state.context.append_token_ids(token_list)
|
||||
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
||||
return removed_len
|
||||
|
||||
def _current_tokens(self) -> mx.array:
|
||||
toks = self.state.tokens
|
||||
if toks[0].shape[0] == 1:
|
||||
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
|
||||
if not self.state.context.is_empty():
|
||||
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
|
||||
toks = [context_toks] + toks
|
||||
if len(toks) > 1:
|
||||
current_tokens = mx.concatenate(toks, axis=1)
|
||||
else:
|
||||
current_tokens = toks[0]
|
||||
logger.debug("debug print current_tokens:")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
return current_tokens
|
||||
|
||||
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
|
||||
if self.state.always_fire:
|
||||
return True
|
||||
if self.state.never_fire:
|
||||
return False
|
||||
return True # MLX CIF not implemented
|
||||
|
||||
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
|
||||
n_audio = encoder_features.shape[0]
|
||||
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
|
||||
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
|
||||
logits = logits[:, 0]
|
||||
|
||||
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
|
||||
language_token_indices = mx.array(
|
||||
list(self.tokenizer.all_language_tokens), dtype=mx.int32,
|
||||
)
|
||||
mask = mask.at[language_token_indices].add(False)
|
||||
logits = mx.where(mask, mx.array(-float('inf')), logits)
|
||||
|
||||
language_tokens = mx.argmax(logits, axis=-1)
|
||||
language_token_probs = mx.softmax(logits, axis=-1)
|
||||
probs_np = np.array(language_token_probs)
|
||||
language_probs = [
|
||||
{
|
||||
c: float(probs_np[i, j])
|
||||
for j, c in zip(
|
||||
self.tokenizer.all_language_tokens,
|
||||
self.tokenizer.all_language_codes,
|
||||
)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
self._clean_cache()
|
||||
return language_tokens, language_probs
|
||||
|
||||
def _concat_segments(self):
|
||||
if len(self.state.segments) > 1:
|
||||
return np.concatenate(self.state.segments, axis=0)
|
||||
return self.state.segments[0]
|
||||
|
||||
def _encode(self, input_segments):
|
||||
mlx_mel_padded = mlx_log_mel_spectrogram(
|
||||
audio=input_segments,
|
||||
n_mels=self.model.dims.n_mels,
|
||||
padding=N_SAMPLES,
|
||||
)
|
||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||
encoder_feature = self.model.encoder(mlx_mel[None])
|
||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
|
||||
return encoder_feature, content_mel_len
|
||||
|
||||
def _init_sum_logprobs(self):
|
||||
return mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
|
||||
|
||||
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
|
||||
if self.state.decoder_type == "greedy":
|
||||
logits, self.state.kv_cache, cross_qk = self.model.decoder(
|
||||
tokens, encoder_feature, kv_cache=self.state.kv_cache,
|
||||
)
|
||||
return logits, cross_qk
|
||||
else:
|
||||
return self.state.inference.logits(tokens, encoder_feature)
|
||||
|
||||
def _check_no_speech(self, logits):
|
||||
if self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
|
||||
no_speech_probs = np.array(
|
||||
probs_at_sot[:, self.tokenizer.no_speech],
|
||||
).tolist()
|
||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||
logger.info("no speech, stop")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _suppress_blank_tokens(self, logits):
|
||||
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
|
||||
logits = logits.at[:, blank_tokens].add(-float('inf'))
|
||||
return logits
|
||||
|
||||
def _apply_token_suppression(self, logits):
|
||||
if self.state.suppress_tokens:
|
||||
suppress_indices = mx.array(
|
||||
list(self.state.suppress_tokens), dtype=mx.int32,
|
||||
)
|
||||
logits = logits.at[:, suppress_indices].add(-float('inf'))
|
||||
return logits
|
||||
|
||||
def _update_tokens(self, current_tokens, logits, sum_logprobs):
|
||||
return self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
|
||||
def _process_cross_attention(
|
||||
self, cross_attns: List, content_mel_len: int,
|
||||
) -> mx.array:
|
||||
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
||||
num_decoder_layers = self.num_decoder_layers
|
||||
|
||||
if cross_attns and isinstance(cross_attns[0], list):
|
||||
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
|
||||
else:
|
||||
flattened_attns = cross_attns
|
||||
|
||||
for idx, attn_mat in enumerate(flattened_attns):
|
||||
if attn_mat is None:
|
||||
continue
|
||||
layer_rank = idx % num_decoder_layers
|
||||
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
||||
if not align_heads_in_layer:
|
||||
continue
|
||||
attn_mat = mx.softmax(attn_mat, axis=-1)
|
||||
for align_head_rank, head_id in align_heads_in_layer:
|
||||
if self.cfg.beam_size == 1:
|
||||
if attn_mat.ndim == 4:
|
||||
a = attn_mat[0, head_id, :, :]
|
||||
else:
|
||||
a = attn_mat[head_id, :, :]
|
||||
a = a[None, :, :]
|
||||
else:
|
||||
a = attn_mat[:, head_id, :, :]
|
||||
attn_of_alignment_heads[align_head_rank].append(a)
|
||||
|
||||
tmp = []
|
||||
for mat in attn_of_alignment_heads:
|
||||
if mat:
|
||||
tmp.append(mx.concatenate(mat, axis=1))
|
||||
if not tmp:
|
||||
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
|
||||
|
||||
attn_of_alignment_heads = mx.stack(tmp, axis=1)
|
||||
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
|
||||
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
||||
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
|
||||
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
||||
mx.eval(attn_of_alignment_heads)
|
||||
return attn_of_alignment_heads
|
||||
|
||||
def _get_attended_frames(self, attn):
|
||||
most_attended_frames = mx.argmax(attn[:, -1, :], axis=-1)
|
||||
frames_np = np.array(most_attended_frames)
|
||||
return frames_np.tolist(), int(frames_np[0])
|
||||
|
||||
def _is_special_token(self, current_tokens):
|
||||
return int(np.array(current_tokens[0, -2])) >= DEC_PAD
|
||||
|
||||
def _rewind_tokens(self):
|
||||
if len(self.state.tokens) > 0:
|
||||
return mx.concatenate(self.state.tokens, axis=1)
|
||||
return self.state.tokens[0]
|
||||
|
||||
def _tokens_to_list(self, current_tokens, start_col):
|
||||
return np.array(current_tokens[0, start_col:]).tolist()
|
||||
|
||||
def _make_new_tokens_tensor(self, hypothesis):
|
||||
new_tokens = mx.array([hypothesis], dtype=mx.int32)
|
||||
return mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
|
||||
|
||||
def _evaluate(self, tensor):
|
||||
mx.eval(tensor)
|
||||
95
whisperlivekit/simul_whisper/mlx_encoder.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_unflatten
|
||||
from mlx_whisper import whisper
|
||||
|
||||
from whisperlivekit.model_mapping import MLX_MODEL_MAPPING
|
||||
|
||||
mlx_model_mapping = MLX_MODEL_MAPPING
|
||||
|
||||
def load_mlx_encoder(
|
||||
path_or_hf_repo: str,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> whisper.Whisper:
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
|
||||
|
||||
with open(str(model_path / "config.json"), "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
quantization = config.pop("quantization", None)
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
weights = mx.load(str(wf))
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
if quantization is not None:
|
||||
class_predicate = (
|
||||
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||
and f"{p}.scales" in weights
|
||||
)
|
||||
nn.quantize(model, **quantization, class_predicate=class_predicate)
|
||||
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
|
||||
# we only want to load the encoder weights here.
|
||||
# Size examples: for tiny.en,
|
||||
# Decoder weights: 59110771 bytes
|
||||
# Encoder weights: 15268874 bytes
|
||||
|
||||
|
||||
encoder_weights = {}
|
||||
encoder_weights['encoder'] = weights['encoder']
|
||||
del(weights)
|
||||
|
||||
|
||||
|
||||
model.update(encoder_weights)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
|
||||
def load_mlx_model(
|
||||
path_or_hf_repo: str,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> whisper.Whisper:
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
|
||||
|
||||
with open(str(model_path / "config.json"), "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
quantization = config.pop("quantization", None)
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
weights = mx.load(str(wf))
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
if quantization is not None:
|
||||
class_predicate = (
|
||||
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||
and f"{p}.scales" in weights
|
||||
)
|
||||
nn.quantize(model, **quantization, class_predicate=class_predicate)
|
||||
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
|
||||
model.update(weights)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
@@ -1,227 +1,198 @@
|
||||
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
||||
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .whisper import load_model, DecodingOptions, tokenizer
|
||||
from .config import AlignAttConfig
|
||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from .whisper.timing import median_filter
|
||||
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
|
||||
TOKENS_PER_SECOND,
|
||||
log_mel_spectrogram, pad_or_trim)
|
||||
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
|
||||
SuppressTokens)
|
||||
from whisperlivekit.whisper.timing import median_filter
|
||||
|
||||
from .align_att_base import DEC_PAD, AlignAttBase
|
||||
from .beam import BeamPyTorchInference
|
||||
from .config import AlignAttConfig
|
||||
from .decoder_state import DecoderState
|
||||
from .eow_detection import fire_at_boundary, load_cif
|
||||
import os
|
||||
from .token_buffer import TokenBuffer
|
||||
|
||||
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer
|
||||
|
||||
import numpy as np
|
||||
from .generation_progress import *
|
||||
|
||||
DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import sys
|
||||
if mlx_backend_available():
|
||||
from mlx_whisper.audio import \
|
||||
log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
|
||||
# New features added to the original version of Simul-Whisper:
|
||||
# - large-v3 model support
|
||||
# - translation support
|
||||
# - beam search
|
||||
# - prompt -- static vs. non-static
|
||||
# - context
|
||||
class PaddedAlignAttWhisper:
|
||||
def __init__(self, cfg: AlignAttConfig) -> None:
|
||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
||||
self.model = load_model(name=model_name, download_root=model_path)
|
||||
if faster_backend_available():
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
|
||||
USE_MLCORE = False
|
||||
|
||||
|
||||
def load_coreml_encoder():
|
||||
try:
|
||||
from coremltools.models import MLModel
|
||||
except ImportError:
|
||||
logger.warning("coremltools is not installed")
|
||||
return None
|
||||
COREML_ENCODER_PATH = os.environ.get(
|
||||
"MLCORE_ENCODER_PATH",
|
||||
"whisperlivekit/whisper/whisper_encoder.mlpackage",
|
||||
)
|
||||
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
|
||||
spec = _coreml_encoder.get_spec()
|
||||
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
|
||||
_coreml_output_name = spec.description.output[0].name if spec.description.output else None
|
||||
return _coreml_encoder, _coreml_input_name, _coreml_output_name
|
||||
|
||||
|
||||
class AlignAtt(AlignAttBase):
|
||||
"""
|
||||
PyTorch Alignment-based Attention decoder for SimulStreaming.
|
||||
|
||||
Hookless — the model can be shared across multiple sessions,
|
||||
with each session maintaining its own DecoderState.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: AlignAttConfig,
|
||||
loaded_model=None,
|
||||
mlx_encoder=None,
|
||||
fw_encoder=None,
|
||||
) -> None:
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
if fw_encoder:
|
||||
self.fw_feature_extractor = FeatureExtractor(
|
||||
feature_size=loaded_model.dims.n_mels,
|
||||
)
|
||||
self.coreml_encoder_tuple = None
|
||||
if USE_MLCORE:
|
||||
self.coreml_encoder_tuple = load_coreml_encoder()
|
||||
self.use_mlcore = self.coreml_encoder_tuple is not None
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
# Common init (sets self.model, self.cfg, decode_options, etc.)
|
||||
self._base_init(cfg, loaded_model)
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
|
||||
decode_options = DecodingOptions(
|
||||
language = cfg.language,
|
||||
without_timestamps = True,
|
||||
task=cfg.task
|
||||
# Per-session state
|
||||
self.state = DecoderState()
|
||||
self._init_state(cfg)
|
||||
|
||||
def _init_state(self, cfg: AlignAttConfig):
|
||||
self._init_state_common(cfg)
|
||||
|
||||
# CIF helpers for end-of-word boundary detection
|
||||
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
|
||||
cfg, n_audio_state=self.model.dims.n_audio_state, device=self.model.device,
|
||||
)
|
||||
self.tokenizer = tokenizer.get_tokenizer(
|
||||
multilingual=not model_name.endswith(".en"),
|
||||
language=cfg.language,
|
||||
num_languages=self.model.num_languages,
|
||||
task=decode_options.task
|
||||
)
|
||||
self.max_text_len = self.model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||
self.cfg = cfg
|
||||
|
||||
|
||||
# model to detect end-of-word boundary at the end of the segment
|
||||
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
||||
n_audio_state=self.model.dims.n_audio_state,
|
||||
device=self.model.device)
|
||||
|
||||
# install hooks to access encoder-decoder attention
|
||||
self.dec_attns = []
|
||||
def layer_hook(module, net_input, net_output):
|
||||
# net_output[1]: B*num_head*token_len*audio_len
|
||||
t = F.softmax(net_output[1], dim=-1)
|
||||
self.dec_attns.append(t.squeeze(0))
|
||||
for b in self.model.decoder.blocks:
|
||||
b.cross_attn.register_forward_hook(layer_hook)
|
||||
|
||||
self.kv_cache = {}
|
||||
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
|
||||
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
|
||||
# save as-is, for the first token or cross attention
|
||||
self.kv_cache[module.cache_id] = net_output
|
||||
else:
|
||||
x = self.kv_cache[module.cache_id]
|
||||
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
|
||||
return self.kv_cache[module.cache_id]
|
||||
|
||||
for i,b in enumerate(self.model.decoder.blocks):
|
||||
b.attn.key.register_forward_hook(kv_hook)
|
||||
b.attn.value.register_forward_hook(kv_hook)
|
||||
b.cross_attn.key.register_forward_hook(kv_hook)
|
||||
b.cross_attn.value.register_forward_hook(kv_hook)
|
||||
|
||||
self.align_source = {}
|
||||
self.num_align_heads = 0
|
||||
# Build alignment source mapping
|
||||
self.state.align_source = {}
|
||||
self.state.num_align_heads = 0
|
||||
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
||||
layer_rank = layer_rank.item()
|
||||
heads = self.align_source.get(layer_rank, [])
|
||||
heads.append((self.num_align_heads, head_id.item()))
|
||||
self.align_source[layer_rank] = heads
|
||||
self.num_align_heads += 1
|
||||
heads = self.state.align_source.get(layer_rank, [])
|
||||
heads.append((self.state.num_align_heads, head_id.item()))
|
||||
self.state.align_source[layer_rank] = heads
|
||||
self.state.num_align_heads += 1
|
||||
|
||||
|
||||
# init tokens (mandatory prompt)
|
||||
self.initial_tokens = torch.tensor(
|
||||
self.tokenizer.sot_sequence_including_notimestamps,
|
||||
dtype=torch.long,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
self.initial_token_length = self.initial_tokens.shape[1]
|
||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
|
||||
# tokens to be suppressed from decoding, to prevent hallucinations
|
||||
# Build suppress tokens function
|
||||
suppress_tokens = [
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
# self.tokenizer.eot
|
||||
self.tokenizer.no_timestamps, # added by DM
|
||||
] + list(self.tokenizer.all_language_tokens) # added by DM
|
||||
self.tokenizer.transcribe, self.tokenizer.translate,
|
||||
self.tokenizer.sot, self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm, self.tokenizer.no_timestamps,
|
||||
] + list(self.tokenizer.all_language_tokens)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||
suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||
logger.debug(f"Suppress tokens: {suppress_tokens}")
|
||||
sup_tokens = SuppressTokens(suppress_tokens)
|
||||
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
||||
# blank tokens are suppresed for new segments near the line 334
|
||||
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
|
||||
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
|
||||
# decoder type: greedy or beam
|
||||
# Decoder type
|
||||
self.state.decoder_type = cfg.decoder_type
|
||||
if cfg.decoder_type == "greedy":
|
||||
logger.info("Using greedy decoder")
|
||||
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
||||
self.decoder_type = "greedy"
|
||||
|
||||
self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
||||
elif cfg.decoder_type == "beam":
|
||||
self.decoder_type = "beam"
|
||||
self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
logger.info("Using beam decoder")
|
||||
self.state.inference = BeamPyTorchInference(
|
||||
self.model, self.state.initial_token_length,
|
||||
)
|
||||
self.state.inference.kv_cache = self.state.kv_cache
|
||||
self.state.token_decoder = BeamSearchDecoder(
|
||||
inference=self.state.inference,
|
||||
eot=self.tokenizer.eot,
|
||||
beam_size=cfg.beam_size,
|
||||
)
|
||||
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
# === Abstract method implementations ===
|
||||
|
||||
# init state
|
||||
self.segments = []
|
||||
self.tokens = [self.initial_tokens]
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
else:
|
||||
self.max_context_tokens = self.cfg.max_context_tokens
|
||||
self.init_context()
|
||||
def init_tokens(self):
|
||||
logger.debug(f"init tokens, {len(self.state.segments)}")
|
||||
self.state.initial_tokens = torch.tensor(
|
||||
self.tokenizer.sot_sequence_including_notimestamps,
|
||||
dtype=torch.long, device=self.model.device,
|
||||
).unsqueeze(0)
|
||||
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
||||
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
||||
self.state.tokens = [self.state.initial_tokens]
|
||||
|
||||
def init_context(self):
|
||||
kw = {'tokenizer': self.tokenizer,
|
||||
'device': self.model.device,
|
||||
'prefix_token_ids': [self.tokenizer.sot_prev]}
|
||||
self.context = TokenBuffer.empty(**kw)
|
||||
kw = {
|
||||
'tokenizer': self.tokenizer,
|
||||
'device': self.model.device,
|
||||
'prefix_token_ids': [self.tokenizer.sot_prev],
|
||||
}
|
||||
self.state.context = TokenBuffer.empty(**kw)
|
||||
if self.cfg.static_init_prompt is not None:
|
||||
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||
self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||
if self.cfg.init_prompt is not None:
|
||||
self.context.text += self.cfg.init_prompt
|
||||
|
||||
def trim_context(self):
|
||||
logger.info("Trimming context")
|
||||
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
||||
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
|
||||
logger.info(f"Context text: {self.context.as_text()}")
|
||||
# logger.debug(f"Context tensor: {self.context.as_tensor()}")
|
||||
l = sum(t.shape[1] for t in self.tokens) + c
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
if self.cfg.static_init_prompt is None:
|
||||
after = 0
|
||||
else:
|
||||
after = len(self.cfg.static_init_prompt)
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
||||
t = self.context.trim_words(after=after)
|
||||
l -= t
|
||||
c -= t
|
||||
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
if t == 0:
|
||||
break
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
logger.info(f"Context after trim: {self.context.text} (len: {l})")
|
||||
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
|
||||
if self.cfg.decoder_type == "greedy":
|
||||
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
else:
|
||||
logger.debug(f"Logits shape: {tokens.shape}")
|
||||
logit = self.inference.logits(tokens, audio_features)
|
||||
return logit
|
||||
|
||||
|
||||
def refresh_segment(self, complete=False):
|
||||
|
||||
logger.debug("Refreshing segment")
|
||||
self.tokens = [self.initial_tokens]
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
if not complete and len(self.segments) > 2:
|
||||
self.segments = self.segments[-2:]
|
||||
else:
|
||||
self.segments = []
|
||||
|
||||
|
||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||
if self.always_fire: return True
|
||||
if self.never_fire: return False
|
||||
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
||||
|
||||
|
||||
self.state.context.text += self.cfg.init_prompt
|
||||
|
||||
def insert_audio(self, segment=None):
|
||||
if segment is not None:
|
||||
self.state.segments.append(segment)
|
||||
removed_len = 0
|
||||
segments_len = self.segments_len()
|
||||
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||
removed_len = self.state.segments[0].shape[0] / 16000
|
||||
segments_len -= removed_len
|
||||
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
||||
self.state.cumulative_time_offset += removed_len
|
||||
self.state.segments = self.state.segments[1:]
|
||||
logger.debug(
|
||||
f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, "
|
||||
f"cumulative offset: {self.state.cumulative_time_offset:.2f}s"
|
||||
)
|
||||
if len(self.state.tokens) > 1:
|
||||
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
|
||||
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
||||
return removed_len
|
||||
|
||||
def _current_tokens(self):
|
||||
|
||||
toks = self.tokens
|
||||
# very first infer: duplicate start of seq to beam_size
|
||||
toks = self.state.tokens
|
||||
if toks[0].shape[0] == 1:
|
||||
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
|
||||
|
||||
if not self.context.is_empty():
|
||||
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
|
||||
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
|
||||
if not self.state.context.is_empty():
|
||||
context_toks = self.state.context.as_tensor_beam(
|
||||
self.cfg.beam_size, device=self.model.device,
|
||||
)
|
||||
toks = [context_toks] + toks
|
||||
|
||||
# make it one tensor
|
||||
if len(toks) > 1:
|
||||
current_tokens = torch.cat(toks, dim=1)
|
||||
else:
|
||||
@@ -230,296 +201,208 @@ class PaddedAlignAttWhisper:
|
||||
self.debug_print_tokens(current_tokens)
|
||||
return current_tokens
|
||||
|
||||
|
||||
def debug_print_tokens(self, tokens):
|
||||
for i in range(self.cfg.beam_size):
|
||||
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
|
||||
|
||||
### audio buffer
|
||||
|
||||
def segments_len(self):
|
||||
segments_len = sum(s.shape[0] for s in self.segments) / 16000
|
||||
return segments_len
|
||||
|
||||
def _apply_minseglen(self):
|
||||
segments_len = self.segments_len()
|
||||
# wait for long enough audio to start
|
||||
if segments_len < self.cfg.audio_min_len:
|
||||
logger.debug("waiting for next segment")
|
||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||
if self.state.always_fire:
|
||||
return True
|
||||
if self.state.never_fire:
|
||||
return False
|
||||
return True
|
||||
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
|
||||
|
||||
def insert_audio(self, segment=None):
|
||||
if segment is not None:
|
||||
self.segments.append(segment)
|
||||
@torch.no_grad()
|
||||
def lang_id(self, encoder_features):
|
||||
n_audio = encoder_features.shape[0]
|
||||
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device)
|
||||
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||
|
||||
removed_len = 0
|
||||
# len of audio is bigger than buffer_len. Going to remove the first segment
|
||||
segments_len = self.segments_len()
|
||||
while segments_len > self.cfg.audio_max_len:
|
||||
removed_len = self.segments[0].shape[0] / 16000
|
||||
segments_len -= removed_len
|
||||
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
||||
self.segments = self.segments[1:]
|
||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
|
||||
self.context.append_token_ids(self.tokens[1][0,:])
|
||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||
return removed_len
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(self.tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(
|
||||
self.tokenizer.all_language_tokens,
|
||||
self.tokenizer.all_language_codes,
|
||||
)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
single = encoder_features.ndim == 2
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
self._clean_cache()
|
||||
return language_tokens, language_probs
|
||||
|
||||
def _concat_segments(self):
|
||||
if len(self.state.segments) > 1:
|
||||
return torch.cat(self.state.segments, dim=0)
|
||||
return self.state.segments[0]
|
||||
|
||||
### transcription / translation
|
||||
def _encode(self, input_segments):
|
||||
if self.use_mlcore:
|
||||
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
|
||||
mel_padded = log_mel_spectrogram(
|
||||
input_segments, n_mels=self.model.dims.n_mels,
|
||||
padding=N_SAMPLES, device="cpu",
|
||||
).unsqueeze(0)
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
||||
mel_np = np.ascontiguousarray(mel.numpy())
|
||||
ml_inputs = {coreml_input_name or "mel": mel_np}
|
||||
coreml_outputs = coreml_encoder.predict(ml_inputs)
|
||||
if coreml_output_name and coreml_output_name in coreml_outputs:
|
||||
encoder_feature_np = coreml_outputs[coreml_output_name]
|
||||
else:
|
||||
encoder_feature_np = next(iter(coreml_outputs.values()))
|
||||
encoder_feature = torch.as_tensor(
|
||||
np.array(encoder_feature_np), device=self.device,
|
||||
)
|
||||
if self.mlx_encoder:
|
||||
mlx_mel_padded = mlx_log_mel_spectrogram(
|
||||
audio=input_segments.detach(),
|
||||
n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||
)
|
||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
||||
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
|
||||
elif self.fw_encoder:
|
||||
audio_length_seconds = len(input_segments) / 16000
|
||||
content_mel_len = int(audio_length_seconds * 100) // 2
|
||||
mel_padded_2 = self.fw_feature_extractor(
|
||||
waveform=input_segments.numpy(), padding=N_SAMPLES,
|
||||
)[None, :]
|
||||
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
||||
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
||||
if self.device == 'cpu':
|
||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||
try:
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
except TypeError:
|
||||
# Some numpy/ctranslate2 versions produce object_ dtype arrays; force float32
|
||||
arr = np.array(encoder_feature_ctranslate)
|
||||
if arr.dtype == np.object_:
|
||||
arr = np.array(arr.tolist(), dtype=np.float32)
|
||||
encoder_feature = torch.as_tensor(arr, device=self.device)
|
||||
else:
|
||||
mel_padded = log_mel_spectrogram(
|
||||
input_segments, n_mels=self.model.dims.n_mels,
|
||||
padding=N_SAMPLES, device=self.device,
|
||||
).unsqueeze(0)
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
||||
encoder_feature = self.model.encoder(mel)
|
||||
return encoder_feature, content_mel_len
|
||||
|
||||
def _init_sum_logprobs(self):
|
||||
return torch.zeros(self.cfg.beam_size, device=self.device)
|
||||
|
||||
def _get_logits_and_cross_attn(self, tokens, encoder_feature):
|
||||
if self.state.decoder_type == "greedy":
|
||||
return self.model.decoder(
|
||||
tokens, encoder_feature,
|
||||
kv_cache=self.state.kv_cache,
|
||||
return_cross_attn=True,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Logits shape: {tokens.shape}")
|
||||
return self.state.inference.logits(
|
||||
tokens, encoder_feature, return_cross_attn=True,
|
||||
)
|
||||
|
||||
def _check_no_speech(self, logits):
|
||||
if self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||
logger.info("no speech, stop")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _suppress_blank_tokens(self, logits):
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
return logits
|
||||
|
||||
def _apply_token_suppression(self, logits):
|
||||
self.state.suppress_tokens_fn(logits)
|
||||
return logits
|
||||
|
||||
def _update_tokens(self, current_tokens, logits, sum_logprobs):
|
||||
return self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
|
||||
def _process_cross_attention(
|
||||
self, cross_attns: List, content_mel_len: int,
|
||||
) -> torch.Tensor:
|
||||
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
||||
num_decoder_layers = len(self.model.decoder.blocks)
|
||||
|
||||
if cross_attns and isinstance(cross_attns[0], list):
|
||||
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
|
||||
else:
|
||||
flattened_attns = cross_attns
|
||||
|
||||
for idx, attn_mat in enumerate(flattened_attns):
|
||||
layer_rank = idx % num_decoder_layers
|
||||
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
||||
if not align_heads_in_layer:
|
||||
continue
|
||||
attn_mat = F.softmax(attn_mat, dim=-1)
|
||||
for align_head_rank, head_id in align_heads_in_layer:
|
||||
if self.cfg.beam_size == 1:
|
||||
if attn_mat.dim() == 4:
|
||||
a = attn_mat[0, head_id, :, :]
|
||||
else:
|
||||
a = attn_mat[head_id, :, :]
|
||||
a = a.unsqueeze(0)
|
||||
else:
|
||||
a = attn_mat[:, head_id, :, :]
|
||||
attn_of_alignment_heads[align_head_rank].append(a)
|
||||
|
||||
tmp = []
|
||||
for mat in attn_of_alignment_heads:
|
||||
if mat:
|
||||
tmp.append(torch.cat(mat, dim=1))
|
||||
if not tmp:
|
||||
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
|
||||
|
||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||
std, mean = torch.std_mean(
|
||||
attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False,
|
||||
)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
|
||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
||||
return attn_of_alignment_heads
|
||||
|
||||
def _get_attended_frames(self, attn):
|
||||
most_attended_frames = torch.argmax(attn[:, -1, :], dim=-1)
|
||||
return most_attended_frames.tolist(), most_attended_frames[0].item()
|
||||
|
||||
def _is_special_token(self, current_tokens):
|
||||
return current_tokens[0, -2].item() >= DEC_PAD
|
||||
|
||||
def _rewind_tokens(self):
|
||||
if len(self.state.tokens) > 0:
|
||||
return torch.cat(self.state.tokens, dim=1)
|
||||
return self.state.tokens[0]
|
||||
|
||||
def _tokens_to_list(self, current_tokens, start_col):
|
||||
return current_tokens[0, start_col:].flatten().tolist()
|
||||
|
||||
def _make_new_tokens_tensor(self, hypothesis):
|
||||
return (
|
||||
torch.tensor([hypothesis], dtype=torch.long)
|
||||
.repeat_interleave(self.cfg.beam_size, dim=0)
|
||||
.to(device=self.device)
|
||||
)
|
||||
|
||||
def _evaluate(self, tensor):
|
||||
pass # No-op for PyTorch
|
||||
|
||||
@torch.no_grad()
|
||||
def infer(self, is_last=False):
|
||||
new_segment = True
|
||||
if len(self.segments) == 0:
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
if len(self.segments) > 1:
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
else:
|
||||
input_segments = self.segments[0]
|
||||
|
||||
self.trim_context()
|
||||
current_tokens = self._current_tokens()
|
||||
|
||||
# mel + padding to 30s
|
||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
# trim to 3000
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
|
||||
# the len of actual audio
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||
|
||||
encoder_feature = self.model.encoder(mel)
|
||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
|
||||
completed = False
|
||||
|
||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||
|
||||
|
||||
####################### Decoding loop
|
||||
logger.info("Decoding loop starts\n")
|
||||
|
||||
attn_of_alignment_heads = None
|
||||
miost_attended_frame = None
|
||||
|
||||
token_len_before_decoding = current_tokens.shape[1]
|
||||
|
||||
generation_progress = []
|
||||
generation = {
|
||||
"starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
|
||||
"token_len_before_decoding": token_len_before_decoding,
|
||||
#"fire_detected": fire_detected,
|
||||
"frames_len": content_mel_len,
|
||||
"frames_threshold": 4 if is_last else self.cfg.frame_threshold,
|
||||
|
||||
# to be filled later
|
||||
"logits_starting": None,
|
||||
|
||||
# to be filled later
|
||||
"no_speech_prob": None,
|
||||
"no_speech": False,
|
||||
|
||||
# to be filled in the loop
|
||||
"progress": generation_progress,
|
||||
}
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
generation_progress_loop = []
|
||||
|
||||
if new_segment:
|
||||
tokens_for_logits = current_tokens
|
||||
else:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens_for_logits = current_tokens[:,-1:]
|
||||
|
||||
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
||||
if new_segment:
|
||||
generation["logits_starting"] = Logits(logits[:,:,:])
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
generation["no_speech_prob"] = no_speech_probs[0]
|
||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||
generation["no_speech"] = True
|
||||
logger.info("no speech, stop")
|
||||
break
|
||||
|
||||
logits = logits[:, -1, :] # logits for the last token
|
||||
generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
|
||||
|
||||
# supress blank tokens only at the beginning of the segment
|
||||
if new_segment:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
new_segment = False
|
||||
self.suppress_tokens(logits)
|
||||
#generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
|
||||
generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
|
||||
|
||||
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
|
||||
generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
|
||||
generation_progress_loop.append(("completed",completed))
|
||||
|
||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
|
||||
|
||||
# if self.decoder_type == "beam":
|
||||
# logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
|
||||
|
||||
# logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
# idx = 0
|
||||
# logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
|
||||
# logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
|
||||
# if completed:
|
||||
# self.debug_print_tokens(current_tokens)
|
||||
|
||||
# logger.debug("decode stopped because decoder completed")
|
||||
|
||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
||||
for i, attn_mat in enumerate(self.dec_attns):
|
||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
||||
align_heads_in_layer = self.align_source.get(layer_rank, [])
|
||||
if len(align_heads_in_layer) == 0:
|
||||
continue
|
||||
for align_head_rank, head_id in align_heads_in_layer:
|
||||
if self.cfg.beam_size == 1:
|
||||
a = attn_mat[head_id, :, :]
|
||||
a = a.unsqueeze(0)
|
||||
else:
|
||||
a = attn_mat[:, head_id, :, :]
|
||||
attn_of_alignment_heads[align_head_rank].append(a)
|
||||
tmp = []
|
||||
for mat in attn_of_alignment_heads:
|
||||
t = torch.cat(mat, dim=1)
|
||||
tmp.append(t)
|
||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
|
||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
|
||||
|
||||
# for each beam, the most attended frame is:
|
||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
|
||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||
|
||||
most_attended_frame = most_attended_frames[0].item()
|
||||
|
||||
|
||||
generation_progress.append(dict(generation_progress_loop))
|
||||
logger.debug("current tokens" + str(current_tokens.shape))
|
||||
if completed:
|
||||
# # stripping the last token, the eot
|
||||
current_tokens = current_tokens[:, :-1]
|
||||
break
|
||||
|
||||
# for some rare cases where the attention fails
|
||||
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
||||
# TODO: check this
|
||||
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
|
||||
logger.debug("ommit rewinding from special tokens")
|
||||
self.last_attend_frame = most_attended_frame
|
||||
else:
|
||||
logger.debug(
|
||||
f"[rewind detected] current attention pos: {most_attended_frame}, "
|
||||
f"last attention pos: {self.last_attend_frame}; omit this segment")
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
|
||||
break
|
||||
else:
|
||||
self.last_attend_frame = most_attended_frame
|
||||
|
||||
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
||||
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
||||
# stripping the last token, the one that is attended too close to the end
|
||||
current_tokens = current_tokens[:, :-1]
|
||||
break
|
||||
|
||||
# debug print
|
||||
for i in range(self.cfg.beam_size):
|
||||
logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
|
||||
attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
|
||||
most_attended_frames[i],
|
||||
current_tokens[i, -1].item(),
|
||||
self.tokenizer.decode([current_tokens[i, -1].item()])
|
||||
))
|
||||
|
||||
# for k,v in generation.items():
|
||||
# print(k,v,file=sys.stderr)
|
||||
# for x in generation_progress:
|
||||
# for y in x.items():
|
||||
# print("\t\t",*y,file=sys.stderr)
|
||||
# print("\t","----", file=sys.stderr)
|
||||
# print("\t", "end of generation_progress_loop", file=sys.stderr)
|
||||
# sys.exit(1)
|
||||
####################### End of decoding loop
|
||||
|
||||
logger.info("End of decoding loop")
|
||||
|
||||
# if attn_of_alignment_heads is not None:
|
||||
# seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
|
||||
|
||||
# # Lets' now consider only the top hypothesis in the beam search
|
||||
# top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
|
||||
|
||||
# # debug print: how is the new token attended?
|
||||
# new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
|
||||
# logger.debug(f"New token attention shape: {new_token_attn.shape}")
|
||||
# if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
|
||||
# logger.debug("no token generated")
|
||||
# else: # it is, and the max attention is:
|
||||
# new_token_max_attn, _ = new_token_attn.max(dim=-1)
|
||||
# logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
|
||||
|
||||
|
||||
# let's now operate only with the top beam hypothesis
|
||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||
if fire_detected or is_last:
|
||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||
else:
|
||||
# going to truncate the tokens after the last space
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
|
||||
generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
|
||||
generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
|
||||
|
||||
# text_to_split = self.tokenizer.decode(tokens_to_split)
|
||||
# logger.debug(f"text_to_split: {text_to_split}")
|
||||
# logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
|
||||
# text_before_space = " ".join(text_to_split.split(" ")[:-1])
|
||||
# logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
|
||||
if len(split_words) > 1:
|
||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||
else:
|
||||
new_hypothesis = []
|
||||
|
||||
|
||||
### new hypothesis
|
||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||
device=self.model.device,
|
||||
)
|
||||
self.tokens.append(new_tokens)
|
||||
# TODO: test if this is redundant or not
|
||||
# ret = ret[ret<DEC_PAD]
|
||||
|
||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||
|
||||
# cleaning cache
|
||||
self.dec_attns = []
|
||||
self.kv_cache = {}
|
||||
if self.decoder_type == "beam":
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
self.token_decoder.reset()
|
||||
|
||||
return new_hypothesis, generation
|
||||
return super().infer(is_last)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import torch
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TokenBuffer:
|
||||
|
||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
||||
@@ -7,6 +10,7 @@ class TokenBuffer:
|
||||
self.prefix_token_ids = prefix_token_ids
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_token_ids(self, tokenizer=None):
|
||||
|
||||
@@ -64,7 +68,26 @@ class TokenBuffer:
|
||||
def append_token_ids(self, token_ids):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
self.text += self.tokenizer.decode(token_ids)
|
||||
|
||||
all_tokens = self.pending_token_ids + token_ids
|
||||
|
||||
decoded = tokenizer.decode(all_tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
if replacement_char in decoded:
|
||||
if len(all_tokens) > 1:
|
||||
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||
|
||||
if replacement_char not in decoded_partial:
|
||||
self.text += decoded_partial
|
||||
self.pending_token_ids = [all_tokens[-1]]
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.text += decoded
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_split_word_tokens(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import ModelDimensions, Whisper
|
||||
from .transcribe import transcribe
|
||||
from .version import __version__
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
alignment_heads = None
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
@@ -1,382 +0,0 @@
|
||||
import base64
|
||||
import gzip
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
|
||||
|
||||
try:
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
SDPA_AVAILABLE = True
|
||||
except (ImportError, RuntimeError, OSError):
|
||||
scaled_dot_product_attention = None
|
||||
SDPA_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
n_mels: int
|
||||
n_audio_ctx: int
|
||||
n_audio_state: int
|
||||
n_audio_head: int
|
||||
n_audio_layer: int
|
||||
n_vocab: int
|
||||
n_text_ctx: int
|
||||
n_text_state: int
|
||||
n_text_head: int
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
# class LayerNorm(nn.LayerNorm):
|
||||
# def forward(self, x: Tensor) -> Tensor:
|
||||
# return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
# class Linear(nn.Linear):
|
||||
# def forward(self, x: Tensor) -> Tensor:
|
||||
# return F.linear(
|
||||
# x,
|
||||
# self.weight.to(x.dtype),
|
||||
# None if self.bias is None else self.bias.to(x.dtype),
|
||||
# )
|
||||
|
||||
|
||||
# class Conv1d(nn.Conv1d):
|
||||
# def _conv_forward(
|
||||
# self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
# ) -> Tensor:
|
||||
# return super()._conv_forward(
|
||||
# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
# )
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
import sys ## this is mine, for debugging
|
||||
class MultiHeadAttention(nn.Module):
|
||||
|
||||
use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212
|
||||
|
||||
def __init__(self, n_state: int, n_head: int, cache_id: str):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = nn.Linear(n_state, n_state)
|
||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
||||
self.key.cache_id = f"{cache_id}_key"
|
||||
self.value = nn.Linear(n_state, n_state)
|
||||
self.value.cache_id = f"{cache_id}_value"
|
||||
self.out = nn.Linear(n_state, n_state)
|
||||
self.cache_id = cache_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
#print("MultiHeadAttention forward",file=sys.stderr)
|
||||
q = self.query(x)
|
||||
# print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
|
||||
# print(mask, kv_cache, xa, file=sys.stderr)
|
||||
|
||||
if kv_cache is None or xa is None or self.key.cache_id not in kv_cache:
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
# print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
|
||||
# if kv_cache is not None:
|
||||
# print(kv_cache.keys())
|
||||
else:
|
||||
# print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache)
|
||||
# if kv_cache is not None:
|
||||
# print(kv_cache.keys())
|
||||
k = kv_cache[self.key.cache_id]
|
||||
v = kv_cache[self.value.cache_id]
|
||||
# print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
# def qkv_attention(
|
||||
# self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
# ):
|
||||
# n_batch, n_ctx, n_state = q.shape
|
||||
# scale = (n_state // self.n_head) ** -0.25
|
||||
# q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
# k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
# v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
# qk = q @ k
|
||||
# if mask is not None:
|
||||
# qk = qk + mask[:n_ctx, :n_ctx]
|
||||
# # qk = qk.float()
|
||||
|
||||
# w = F.softmax(qk, dim=-1) # .to(q.dtype)
|
||||
# return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
|
||||
a = scaled_dot_product_attention(
|
||||
q, k, v, is_causal=mask is not None and n_ctx > 1
|
||||
)
|
||||
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
qk = None
|
||||
else:
|
||||
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
qk = qk.detach()
|
||||
|
||||
return out, qk
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
||||
|
||||
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
||||
)
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
# print("ResidualAttentionBlock forward",file=sys.stderr)
|
||||
# print(x.shape, file=sys.stderr)
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
|
||||
)
|
||||
self.ln_post = nn.LayerNorm(n_state)
|
||||
|
||||
def forward(self, x: Tensor, return_layer_results: bool=False):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1) # BDT -> BTD
|
||||
|
||||
# 两层卷积,2倍降采样
|
||||
# 最终剩下1500帧
|
||||
|
||||
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
|
||||
|
||||
layer_results = []
|
||||
i = 0
|
||||
for block in self.blocks:
|
||||
# print(f"encoder layer {i}")
|
||||
x = block(x)
|
||||
layer_results.append(x)
|
||||
i += 1
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
if return_layer_results:
|
||||
return x, layer_results
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.ln = nn.LayerNorm(n_state)
|
||||
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
# x = x.to(xa.dtype)
|
||||
|
||||
i = 0
|
||||
for block in self.blocks:
|
||||
# print(f"decoder layer {i}")
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
i += 1
|
||||
|
||||
x = self.ln(x)
|
||||
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
self.dims.n_text_state,
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
||||
all_heads = torch.zeros(
|
||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||
)
|
||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||
|
||||
def set_alignment_heads(self, dump: bytes):
|
||||
array = np.frombuffer(
|
||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||
).copy()
|
||||
mask = torch.from_numpy(array).reshape(
|
||||
self.dims.n_text_layer, self.dims.n_text_head
|
||||
)
|
||||
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
||||
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(
|
||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# mel = mel.to(self.decoder.ln.weight.dtype)
|
||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.dims.n_vocab >= 51865
|
||||
|
||||
@property
|
||||
def num_languages(self):
|
||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||
|
||||
# 为decoder加入缓存机制,每次推理时保存上次的k和v,下次推理无需重新计算
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||
intermediate tensors to be reused during later calculations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cache : Dict[nn.Module, torch.Tensor]
|
||||
A dictionary object mapping the key/value projection modules to its cache
|
||||
hooks : List[RemovableHandle]
|
||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||
"""
|
||||
cache = {**cache} if cache is not None else {}
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||
# save as-is, for the first token or cross attention
|
||||
cache[module] = output
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
def install_hooks(layer: nn.Module):
|
||||
if isinstance(layer, MultiHeadAttention):
|
||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||
|
||||
self.decoder.apply(install_hooks)
|
||||
return cache, hooks
|
||||
|
||||
detect_language = detect_language_function
|
||||
transcribe = transcribe_function
|
||||
decode = decode_function
|
||||
@@ -1,501 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from whisper.audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from whisper.decoding import DecodingOptions, DecodingResult
|
||||
from whisper.timing import add_word_timestamps
|
||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from whisper.utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from whisper.model import Whisper
|
||||
|
||||
|
||||
def transcribe(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
Transcribe an audio file using Whisper
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
The Whisper model instance
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
logprob_threshold: float
|
||||
If the average log probability over sampled tokens is below this value, treat as failed
|
||||
|
||||
no_speech_threshold: float
|
||||
If the no_speech probability is higher than this value AND the average log probability
|
||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||
|
||||
condition_on_previous_text: bool
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
word_timestamps: bool
|
||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
||||
and include the timestamps for each word in each segment.
|
||||
|
||||
prepend_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
||||
|
||||
append_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
||||
|
||||
initial_prompt: Optional[str]
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
# print("HACKED")
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||
if dtype == torch.float16:
|
||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||
dtype = torch.float32
|
||||
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # 添加16000*30 = 480000个点
|
||||
# mel = pad_or_trim(mel, 3000)
|
||||
content_frames = mel.shape[-1] # - N_FRAMES # 对应3000帧;真正有内容的是去掉尾部3000的那些数据
|
||||
|
||||
# 判断语种
|
||||
if decode_options.get("language", None) is None:
|
||||
# 如果是单语种模型,直接设成英文
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
# 否则需要前传一次
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
# print(mel_segment.shape)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(
|
||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||
)
|
||||
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
# 输出编码器
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
# 词级别时间戳
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
)
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
kwargs = {**decode_options}
|
||||
if t > 0:
|
||||
# disable beam_size and patience when t > 0
|
||||
kwargs.pop("beam_size", None)
|
||||
kwargs.pop("patience", None)
|
||||
else:
|
||||
# disable best_of when t == 0
|
||||
kwargs.pop("best_of", None)
|
||||
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
# 几种解码可能失败的情况。这些情况下会重复解码
|
||||
# 感觉是一种KnowHow的东西 或许ChatGPT里有不少这种trick
|
||||
needs_fallback = False
|
||||
if (
|
||||
compression_ratio_threshold is not None
|
||||
and decode_result.compression_ratio > compression_ratio_threshold
|
||||
):
|
||||
needs_fallback = True # too repetitive
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = True # average log probability is too low
|
||||
if (
|
||||
no_speech_threshold is not None
|
||||
and decode_result.no_speech_prob > no_speech_threshold
|
||||
):
|
||||
needs_fallback = False # silence
|
||||
if not needs_fallback:
|
||||
break
|
||||
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
|
||||
# t,
|
||||
# decode_result.compression_ratio, compression_ratio_threshold,
|
||||
# -decode_result.avg_logprob, -logprob_threshold,
|
||||
# decode_result.no_speech_prob, no_speech_threshold
|
||||
# ))
|
||||
|
||||
return decode_result
|
||||
|
||||
seek = 0
|
||||
input_stride = exact_div(
|
||||
N_FRAMES, model.dims.n_audio_ctx
|
||||
) # mel frames per output token: 2
|
||||
# 这里output token指的应该是CNN输出的那个东西
|
||||
|
||||
time_precision = (
|
||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||
) # time per output token: 0.02 (seconds)
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
while seek < content_frames: # seek:标记mel频谱当前帧的位置 直接跳过Padding上的部分
|
||||
# print("seek segments", seek, content_frames)
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # 本片段的开始时间
|
||||
# mel_segment = mel[:, seek : seek + N_FRAMES] # 获得当前片段的数据
|
||||
mel_segment = mel[:, seek:]
|
||||
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: 排除padding的真的长度。content_frames:有内容的段的真正长度 如果不够N_FRAMES的话就会截断
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # 当前片段的时长
|
||||
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # 补到mel_segment帧
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
# 跳过静音部分
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and result.avg_logprob > logprob_threshold
|
||||
):
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment_size # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp begin是<|0.00|>的token;bos比文字token大,eos的值比bos还大,所以是ge
|
||||
timestamp_tokens[-1] = False
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # 如果最后是[False,True]:本段里一个句子结束了
|
||||
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
|
||||
# timestamp_token就是个一维向量吧 那为啥不直接nonzero
|
||||
# 如果有两个连续的时间戳 这个会是一个一维tensor 是这两个连续时间戳的结尾位置
|
||||
# 多个的话指向第二个 那如果有三个怎么办?
|
||||
# 否则是个0维tensor
|
||||
|
||||
consecutive.add_(1) # 0维tensor+1还是0维 哪儿找的这些edge cases js是吧
|
||||
if len(consecutive) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = consecutive.tolist()
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(tokens)) # 把最后一段的结尾也加进去
|
||||
# print("many sentenses", consecutive)
|
||||
last_slice = 0
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
# 看起来语音开始帧、语音结束帧的位置会被编码到start_timestamp中
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_pos = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
# 获取一个新的语音段
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
last_slice = current_slice
|
||||
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
seek += segment_size
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
# 如果语音尚未结束,那么seek变为上一个结束的语段的位置
|
||||
# 换句话说就是针对30s长的chunk的语音设计的
|
||||
last_timestamp_pos = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_pos * input_stride
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
# print(timestamps)
|
||||
if (
|
||||
len(timestamps) > 0
|
||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||
):
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
# 取最后一个;假设要么有一个结束的time stamp;要么有一对儿?
|
||||
# 如果里面只有一个开始的timestamp 似乎后面的东西都会被丢掉?
|
||||
last_timestamp_pos = (
|
||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
duration = last_timestamp_pos * time_precision
|
||||
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
seek += segment_size
|
||||
|
||||
# 每个token有自己的时间戳
|
||||
if word_timestamps:
|
||||
add_word_timestamps(
|
||||
segments=current_segments,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
mel=mel_segment,
|
||||
num_frames=segment_size,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
last_speech_timestamp=last_speech_timestamp,
|
||||
)
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if len(word_end_timestamps) > 0:
|
||||
last_speech_timestamp = word_end_timestamps[-1]
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
)
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||
print(make_safe(line))
|
||||
|
||||
# if a segment is instantaneous or does not contain text, clear it
|
||||
for i, segment in enumerate(current_segments):
|
||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
||||
segment["text"] = ""
|
||||
segment["tokens"] = []
|
||||
segment["words"] = []
|
||||
|
||||
# 更新结果
|
||||
all_segments.extend(
|
||||
[
|
||||
{"id": i, **segment}
|
||||
for i, segment in enumerate(
|
||||
current_segments, start=len(all_segments)
|
||||
)
|
||||
]
|
||||
)
|
||||
all_tokens.extend(
|
||||
[token for segment in current_segments for token in segment["tokens"]]
|
||||
)
|
||||
|
||||
if not condition_on_previous_text or result.temperature > 0.5:
|
||||
# do not feed the prompt tokens if a high temperature was used
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
# print("太长了")
|
||||
# break
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
||||
if not args["word_timestamps"]:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -1 +0,0 @@
|
||||
__version__ = "20230918"
|
||||
139
whisperlivekit/thread_safety.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Thread Safety Configuration for WhisperLiveKit
|
||||
|
||||
This module provides thread safety configuration and utilities.
|
||||
|
||||
Environment Variables:
|
||||
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
|
||||
Set to "0" to disable for single-connection deployments
|
||||
|
||||
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
|
||||
|
||||
Usage:
|
||||
# Enable model locking (default)
|
||||
export WHISPERLIVEKIT_MODEL_LOCK=1
|
||||
|
||||
# Disable for single-connection deployment
|
||||
export WHISPERLIVEKIT_MODEL_LOCK=0
|
||||
|
||||
# Custom timeout
|
||||
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
|
||||
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
|
||||
|
||||
# Global model lock
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
# Log configuration on import
|
||||
if USE_MODEL_LOCK:
|
||||
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
|
||||
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
|
||||
else:
|
||||
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
|
||||
|
||||
|
||||
def get_model_lock():
|
||||
"""Get the global model lock instance"""
|
||||
return _model_lock
|
||||
|
||||
|
||||
def acquire_model_lock(timeout=None):
|
||||
"""
|
||||
Acquire model lock with timeout.
|
||||
|
||||
Args:
|
||||
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
|
||||
|
||||
Returns:
|
||||
bool: True if lock acquired, False on timeout
|
||||
"""
|
||||
if not USE_MODEL_LOCK:
|
||||
return True
|
||||
|
||||
timeout = timeout or LOCK_TIMEOUT
|
||||
acquired = _model_lock.acquire(timeout=timeout)
|
||||
|
||||
if not acquired:
|
||||
logger.error(f"Failed to acquire model lock within {timeout}s")
|
||||
|
||||
return acquired
|
||||
|
||||
|
||||
def release_model_lock():
|
||||
"""Release model lock"""
|
||||
if not USE_MODEL_LOCK:
|
||||
return
|
||||
|
||||
try:
|
||||
_model_lock.release()
|
||||
except RuntimeError:
|
||||
# Lock not held - this is fine
|
||||
pass
|
||||
|
||||
|
||||
class ModelLockContext:
|
||||
"""Context manager for model lock"""
|
||||
|
||||
def __init__(self, timeout=None):
|
||||
self.timeout = timeout
|
||||
self.acquired = False
|
||||
|
||||
def __enter__(self):
|
||||
self.acquired = acquire_model_lock(self.timeout)
|
||||
return self.acquired
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.acquired:
|
||||
release_model_lock()
|
||||
return False
|
||||
|
||||
|
||||
# Concurrency recommendations
|
||||
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
|
||||
RECOMMENDED_WORKERS = 4
|
||||
|
||||
def print_deployment_recommendations():
|
||||
"""Print recommended deployment configuration"""
|
||||
print("\n" + "="*60)
|
||||
print("WhisperLiveKit Deployment Recommendations")
|
||||
print("="*60)
|
||||
|
||||
if USE_MODEL_LOCK:
|
||||
print("⚠️ Model locking is ENABLED")
|
||||
print(" This serializes inference across connections.")
|
||||
print()
|
||||
print("Recommended deployment:")
|
||||
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
|
||||
print(" -k uvicorn.workers.UvicornWorker \\")
|
||||
print(" --worker-connections 1 \\")
|
||||
print(" whisperlivekit.basic_server:app")
|
||||
print()
|
||||
print("Expected capacity:")
|
||||
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
|
||||
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
|
||||
else:
|
||||
print("✅ Model locking is DISABLED")
|
||||
print(" ⚠️ ONLY safe for single-connection deployments")
|
||||
print()
|
||||
print("Recommended deployment:")
|
||||
print(" uvicorn whisperlivekit.basic_server:app \\")
|
||||
print(" --host 0.0.0.0 --port 8000 \\")
|
||||
print(" --workers 1")
|
||||
print()
|
||||
print("Expected capacity:")
|
||||
print(" - 1 concurrent user only")
|
||||
|
||||
print("="*60 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_deployment_recommendations()
|
||||