42 Commits

Author SHA1 Message Date
Quentin Fuxa
e7e82f7c19 bump to 0.2.18 2026-02-11 22:10:00 +01:00
Quentin Fuxa
8c799fa4d1 fix simulstreaming vram leak: cap cross-attn accumulation + token budget
fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during decoding loop,
  using up to ~5GB for repetition loops. now capped to rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
2026-02-11 22:10:00 +01:00
Quentin Fuxa
8923337380 fix --direct-english-translation not setting task=translate for localagreement backends
the flag was only used for tokenizer language selection but never
actually passed to whisper/faster-whisper transcribe calls. also init
OpenaiApiASR.task and read from transcribe_kargs.

fixes #306
2026-02-11 22:10:00 +01:00
Quentin Fuxa
aded1649ae fix model_cache_dir + direct_english_translation task in simulstreaming
pass actual cache dir instead of None, and use proper task string
instead of boolean for AlignAttConfig

fixes #310
2026-02-11 22:10:00 +01:00
Quentin Fuxa
3b535e857a fix NoneType concatenation in add_translation
fixes #296
2026-02-11 22:10:00 +01:00
Quentin Fuxa
d649250b9a fix Segment classmethod call + isinstance type narrowing
fixes #331, fixes #329
2026-02-11 22:10:00 +01:00
Quentin Fuxa
7735478286 add insert_audio_chunk to DiartDiarization
fixes #332
2026-02-11 22:10:00 +01:00
Quentin Fuxa
b9e72d2b9a add probability field to ASRToken
fixes #330, fixes #313
2026-02-11 22:10:00 +01:00
Quentin Fuxa
e5b01033af add json normalizers for english language in build 2026-01-16 10:47:46 +01:00
Quentin Fuxa
6ae545bcb1 bump to 0.2.17.post1 2026-01-16 10:43:52 +01:00
Quentin Fuxa
04980d3f5e Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-01-16 10:38:29 +01:00
Quentin Fuxa
79a705c969 fixes #323 2026-01-16 10:38:07 +01:00
Quentin Fuxa
34e4abd455 Merge pull request #322 from eschmidbauer/fix/thread-safety-issues
Fix kv cache not being properly cleaned between sessions
2026-01-09 19:23:35 +01:00
Emmanuel Schmidbauer
d59ddbaeae Fix critical thread safety issues 2026-01-09 11:23:19 -05:00
Quentin Fuxa
4dd66e7766 Merge pull request #317 from jantonj/fix-bug-diarization-lag
update diarization lag after stream analysed
2025-12-19 17:43:07 +01:00
Anton Jacobson
3db5d81a20 update diarization lag after stream analysed 2025-12-18 14:13:28 +01:00
Quentin Fuxa
b67ddea494 bump to 0.2.17 2025-12-08 23:52:00 +01:00
Quentin Fuxa
3192553e20 fixes #307 2025-12-09 10:27:49 +01:00
Quentin Fuxa
f379a243fe Merge pull request #274 from blakkd/patch-1
minor path change
2025-12-09 10:10:32 +01:00
Quentin Fuxa
ec09898a9f fixes #301 2025-12-06 10:19:50 +01:00
blakkd
befbae56c7 minor path change
prevents

```
FileNotFoundError: [Errno 2] No such file or directory: 'whisperlivekit/web/live_transcription.html'
```
2025-11-16 23:47:58 +01:00
Quentin Fuxa
719e8b1a20 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
f1b47178d8 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
59db08e961 loader for full mlx 2024-11-25 23:52:00 +01:00
Quentin Fuxa
6fc20b9562 new dec class 2024-11-21 23:52:00 +01:00
Quentin Fuxa
fac8659161 uses native mlx function for attention 2024-11-21 23:52:00 +01:00
Quentin Fuxa
4d9332ce7d fixes #299 2025-12-05 17:54:14 +01:00
Quentin Fuxa
62444ce746 session parameter required in OnnxWrapper 2025-12-05 15:37:18 +01:00
Quentin Fuxa
2431a6bf91 isolated VAD states per user: .onnx: share a stateless model. .jit: require duplicating the model.
Co-authored-by: eschmidbauer <eschmidbauer@gmail.com>
2025-12-05 15:27:14 +01:00
Quentin Fuxa
d1263e7228 Merge pull request #308 from gzz2000/main
Fix local agreement backend, removing excess parameter, #295
2025-12-05 11:34:05 +01:00
Zizheng Guo
30ddd522a4 Fix local agreement backend, removing excess parameter, fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/295 2025-12-04 16:45:23 +08:00
Quentin Fuxa
635bace09e update archi 2025-11-30 18:39:10 +01:00
Quentin Fuxa
f1113e3eb0 update with LoRA 2025-11-29 18:33:30 +01:00
Quentin Fuxa
cc5f819ce7 hf weights 2025-11-29 17:50:46 +01:00
Quentin Fuxa
82cd24bb75 LoRa path v0 - functional 2025-11-29 17:21:10 +01:00
Quentin Fuxa
d45c397c6a simulstreaming: limit n tokens to prevent hallucinations 2025-11-28 21:41:19 +01:00
Quentin Fuxa
45bf3f57d7 troubleshooting doc for aarch64 systems 2025-11-28 21:40:43 +01:00
Quentin Fuxa
1d88ba9d69 Fixes #294. improve model path backend detection and file extraction 2025-11-27 23:14:00 +01:00
Quentin Fuxa
c0965c6c31 Lines to Segments. Merging dataclasses 2025-11-27 21:54:58 +01:00
Quentin Fuxa
34ddd2ac02 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
345d781e97 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
28cf831701 indicate for context token limits for --max-context-tokens. bump to 0.2.16.dev0 2025-11-25 23:45:15 +01:00
31 changed files with 2330 additions and 493 deletions

View File

@@ -37,9 +37,10 @@ RUN pip3 install --upgrade pip setuptools wheel && \
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
# Example: --build-arg EXTRAS="translation"
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \

View File

@@ -1,24 +1,26 @@
<h1 align="center">WhisperLiveKit</h1>
<h1 align="center">WLK</h1>
<p align="center"><b>WhisperLiveKit: Ultra-low-latency, self-hosted speech-to-text with speaker identification</b></p>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://huggingface.co/qfuxa/whisper-base-french-lora">
<img alt="Hugging Face Weights" src="https://img.shields.io/badge/🤗-Hugging%20Face%20Weights-yellow" />
</a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
</p>
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
#### Powered by Leading Research:
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
@@ -143,10 +145,10 @@ async def websocket_endpoint(websocket: WebSocket):
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
@@ -159,6 +161,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). Only works with native Whisper backend (`--backend whisper`) | `None` |
| Translation options | Description | Default |
|-----------|-------------|---------|
@@ -168,7 +171,7 @@ async def websocket_endpoint(websocket: WebSocket):
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
| `--disable-punctuation-split` | [NOT FUNCTIONAL IN 0.2.15 / 0.2.16] Disable punctuation based splits. See #214 | `False` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
@@ -186,7 +189,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--max-context-tokens` | Maximum context tokens | Depends on model used, but usually 448. |
@@ -264,7 +267,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
#### Customization
- `--build-arg` Options:
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
- `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models

Binary file not shown.

Before

Width:  |  Height:  |  Size: 406 KiB

After

Width:  |  Height:  |  Size: 422 KiB

View File

@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
## Running this extension
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.

View File

@@ -1,109 +0,0 @@
# Available Whisper model sizes:
- tiny.en (english only)
- tiny
- base.en (english only)
- base
- small.en (english only)
- small
- medium.en (english only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo
## How to choose?
### Language Support
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
### Resource Constraints
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
- `base`: Good balance of speed and accuracy for basic use cases
- `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
- `large-v2`: Excellent accuracy, good multilingual support
- `large-v3`: Best overall accuracy and language support
### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Model Comparison Table
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
- `largev3turbo`: ~6GB VRAM
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
## Quick Decision Matrix
**Choose 600M**: Limited resources, close to 0 lag
**Choose 1.3B**: Quality matters
**Choose Transformers**: On Apple Silicon

View File

@@ -0,0 +1,106 @@
# Models and Model Paths
## Defaults
**Default Whisper Model**: `base`
When no model is specified, WhisperLiveKit uses the `base` model, which provides a good balance of speed and accuracy for most use cases.
**Default Model Cache Directory**: `~/.cache/whisper`
Models are automatically downloaded from OpenAI's model hub and cached in this directory. You can override this with `--model_cache_dir`.
**Default Translation Model**: `600M` (NLLB-200-distilled)
When translation is enabled, the 600M distilled NLLB model is used by default. This provides good quality with minimal resource usage.
**Default Translation Backend**: `transformers`
The translation backend defaults to Transformers. On Apple Silicon, this automatically uses MPS acceleration for better performance.
---
## Available Whisper model sizes:
| Available Model | Speed | Accuracy | Multilingual | Translation | Hardware Requirements | Best Use Case |
|--------------------|----------|-----------|--------------|-------------|----------------------|----------------------------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | ~1GB VRAM | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | ~1GB VRAM | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | ~2GB VRAM | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | ~5GB VRAM | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Good overall accuracy & language support |
| large-v3 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Best overall accuracy & language support |
| large-v3-turbo | Fast | Excellent | Yes | No | ~6GB VRAM | Fast, high-quality transcription |
### How to choose?
#### Language Support
- **English only**: Use `.en` (ex: `base.en`) models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
#### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
_______________________
# Custom Models:
The `--model-path` parameter accepts:
## File Path
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
## Hugging Face Repo ID
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignment heads are set to be all the heads of the last half layer of decoder.
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2

View File

@@ -1,19 +0,0 @@
# Model Path Formats
The `--model-path` parameter accepts:
## File Path
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
## Hugging Face Repo ID
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.

View File

@@ -1,6 +1,114 @@
# Supported Languages
# Transcription: Supported Language
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
WLK supports transcription in the following languages:
| ISO Code | Language Name |
|----------|---------------------|
| en | English |
| zh | Chinese |
| de | German |
| es | Spanish |
| ru | Russian |
| ko | Korean |
| fr | French |
| ja | Japanese |
| pt | Portuguese |
| tr | Turkish |
| pl | Polish |
| ca | Catalan |
| nl | Dutch |
| ar | Arabic |
| sv | Swedish |
| it | Italian |
| id | Indonesian |
| hi | Hindi |
| fi | Finnish |
| vi | Vietnamese |
| he | Hebrew |
| uk | Ukrainian |
| el | Greek |
| ms | Malay |
| cs | Czech |
| ro | Romanian |
| da | Danish |
| hu | Hungarian |
| ta | Tamil |
| no | Norwegian |
| th | Thai |
| ur | Urdu |
| hr | Croatian |
| bg | Bulgarian |
| lt | Lithuanian |
| la | Latin |
| mi | Maori |
| ml | Malayalam |
| cy | Welsh |
| sk | Slovak |
| te | Telugu |
| fa | Persian |
| lv | Latvian |
| bn | Bengali |
| sr | Serbian |
| az | Azerbaijani |
| sl | Slovenian |
| kn | Kannada |
| et | Estonian |
| mk | Macedonian |
| br | Breton |
| eu | Basque |
| is | Icelandic |
| hy | Armenian |
| ne | Nepali |
| mn | Mongolian |
| bs | Bosnian |
| kk | Kazakh |
| sq | Albanian |
| sw | Swahili |
| gl | Galician |
| mr | Marathi |
| pa | Punjabi |
| si | Sinhala |
| km | Khmer |
| sn | Shona |
| yo | Yoruba |
| so | Somali |
| af | Afrikaans |
| oc | Occitan |
| ka | Georgian |
| be | Belarusian |
| tg | Tajik |
| sd | Sindhi |
| gu | Gujarati |
| am | Amharic |
| yi | Yiddish |
| lo | Lao |
| uz | Uzbek |
| fo | Faroese |
| ht | Haitian Creole |
| ps | Pashto |
| tk | Turkmen |
| nn | Nynorsk |
| mt | Maltese |
| sa | Sanskrit |
| lb | Luxembourgish |
| my | Myanmar |
| bo | Tibetan |
| tl | Tagalog |
| mg | Malagasy |
| as | Assamese |
| tt | Tatar |
| haw | Hawaiian |
| ln | Lingala |
| ha | Hausa |
| ba | Bashkir |
| jw | Javanese |
| su | Sundanese |
| yue | Cantonese |
# Translation: Supported Languages
WLK supports translation into **201 languages** from the FLORES-200 dataset through the [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) translation system.
## How to Specify Languages

View File

@@ -40,4 +40,4 @@ This document introduce how to reuse the core components when you do **not** wan
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently. Just ensure `ffmpeg` is available.

View File

@@ -82,16 +82,43 @@ print(torch.cuda.is_available(), torch.cuda.get_device_name())
```python
import ctranslate2
print("CUDA devices:", ctranslate2.get_cuda_device_count())
print("CUDA compute types:", ctranslate2.get_supported_compute_types("cuda", 0))
```
**Note for aarch64 systems (e.g., NVIDIA DGX Spark):** Pre-built CUDA wheels may not be available for all CUDA versions on ARM architectures. If the wheel installation fails, you may need to compile CTranslate2 from source with CUDA support enabled.
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
---
## Hopper / Blackwell (`sm_121a`) systems
> Reported in issue #276 (NVIDIA DGX Spark)
> Reported in issues #276 and #284 (NVIDIA DGX Spark)
CUDA 12.1a GPUs ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual hints:
CUDA 12.1a GPUs (e.g., NVIDIA GB10 on DGX Spark) ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual configuration.
### Error: `ptxas fatal : Value 'sm_121a' is not defined for option 'gpu-name'`
If you encounter this error after compiling CTranslate2 from source on aarch64 systems, Triton's bundled `ptxas` may not support the `sm_121a` architecture. The solution is to replace Triton's `ptxas` with the system's CUDA `ptxas`:
```bash
# Find your Python environment's Triton directory
python -c "import triton; import os; print(os.path.dirname(triton.__file__))"
# Copy the system ptxas to Triton's backend directory
# Replace <triton_path> with the output above
cp /usr/local/cuda/bin/ptxas <triton_path>/backends/nvidia/bin/ptxas
```
For example, in a virtual environment:
```bash
cp /usr/local/cuda/bin/ptxas ~/wlk/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
```
**Note:** On DGX Spark systems, CUDA is typically already in `PATH` (`/usr/local/cuda/bin`), so explicit `CUDA_HOME` and `PATH` exports may not be necessary. Verify with `which ptxas` before copying.
### Alternative: Environment variable approach
If the above doesn't work, you can try setting environment variables (though this may not resolve the `sm_121a` issue on all systems):
```bash
export CUDA_HOME="/usr/local/cuda-13.0"
@@ -105,7 +132,7 @@ export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
```
After exporting those variables (or adding them to your systemd service / shell profile), restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
After applying the fix, restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
---

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.15"
version = "0.2.18"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
@@ -35,6 +35,7 @@ dependencies = [
"torchaudio>=2.0.0",
"torch>=2.0.0",
"huggingface-hub>=0.25.0",
"faster-whisper>=1.2.0",
"tqdm",
"tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
@@ -56,6 +57,7 @@ packages = [
"whisperlivekit",
"whisperlivekit.diarization",
"whisperlivekit.simul_whisper",
"whisperlivekit.simul_whisper.mlx",
"whisperlivekit.whisper",
"whisperlivekit.whisper.assets",
"whisperlivekit.whisper.normalizers",
@@ -67,4 +69,5 @@ packages = [
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.whisper.normalizers" = ["*.json"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -10,9 +10,9 @@ from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory,
online_translation_factory)
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
Line, Silence, State, Transcript)
Segment, Silence, State, Transcript)
from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -32,7 +32,7 @@ async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.
if isinstance(first_item, Silence):
return first_item
items.append(first_item)
while True:
if not queue._queue:
break
@@ -53,15 +53,15 @@ class AudioProcessor:
Processes audio streams for transcription and diarization.
Handles audio processing, state management, and result formatting.
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
models = kwargs['transcription_engine']
else:
models = TranscriptionEngine(**kwargs)
# Audio processing settings
self.args = models.args
self.sample_rate = 16000
@@ -85,12 +85,14 @@ class AudioProcessor:
# Models and processing
self.asr: Any = models.asr
self.vac_model: Any = models.vac_model
self.vac: Optional[FixedVADIterator] = None
if self.args.vac:
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
else:
self.vac: Optional[FixedVADIterator] = None
if models.vac_session is not None:
vac_model = OnnxWrapper(session=models.vac_session)
self.vac = FixedVADIterator(vac_model)
else:
self.vac = FixedVADIterator(load_jit_vad())
self.ffmpeg_manager: Optional[FFmpegManager] = None
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error: Optional[str] = None
@@ -104,7 +106,7 @@ class AudioProcessor:
logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
@@ -115,14 +117,14 @@ class AudioProcessor:
self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None
self.diarization: Optional[Any] = None
if self.args.transcription:
self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep
self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
if models.translation_model:
@@ -180,24 +182,24 @@ class AudioProcessor:
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def get_current_state(self) -> State:
"""Get current state."""
async with self.lock:
current_time = time()
remaining_transcription = 0
if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0
if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization
return self.state
async def ffmpeg_stdout_reader(self) -> None:
@@ -253,7 +255,7 @@ class AudioProcessor:
async def transcription_processor(self) -> None:
"""Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0
while True:
try:
# item = await self.transcription_queue.get()
@@ -309,12 +311,12 @@ class AudioProcessor:
if new_tokens:
candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto)
async with self.lock:
self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript
@@ -324,13 +326,13 @@ class AudioProcessor:
if self.translation_queue:
for token in new_tokens:
await self.translation_queue.put(token)
await self.translation_queue.put(token)
except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done()
if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue:
@@ -351,18 +353,21 @@ class AudioProcessor:
if item.has_ended:
self.diarization.insert_silence(item.duration)
continue
self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize()
self.state.new_diarization = diarization_segments
diar_end = 0.0
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Diarization processor task finished.")
async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
@@ -424,22 +429,22 @@ class AudioProcessor:
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
)
should_push = (response != self.last_response_content)
if should_push:
yield response
self.last_response_content = response
if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return
await asyncio.sleep(0.05)
except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5)
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks."""
self.all_tasks_for_cleanup = []
@@ -464,21 +469,21 @@ class AudioProcessor:
self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task)
# Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task)
return self.results_formatter()
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
@@ -491,7 +496,7 @@ class AudioProcessor:
return
await asyncio.sleep(10)
for i, task in enumerate(list(tasks_remaining)):
if task.done():
exc = task.exception()
@@ -501,13 +506,13 @@ class AudioProcessor:
else:
logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task)
except asyncio.CancelledError:
logger.info("Watchdog task cancelled.")
break
except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self) -> None:
"""Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.")
@@ -515,7 +520,7 @@ class AudioProcessor:
for task in self.all_tasks_for_cleanup:
if task and not task.done():
task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True)
@@ -553,7 +558,7 @@ class AudioProcessor:
if not message:
logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True
if self.transcription_queue:
await self.transcription_queue.put(SENTINEL)
@@ -594,7 +599,7 @@ class AudioProcessor:
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0:
return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
@@ -611,7 +616,7 @@ class AudioProcessor:
if res is not None:
if "start" in res and self.current_silence:
await self._end_silence()
if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end")

View File

@@ -1,5 +1,6 @@
import logging
import sys
import threading
from argparse import Namespace
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
@@ -19,16 +20,26 @@ logger = logging.getLogger(__name__)
class TranscriptionEngine:
_instance = None
_initialized = False
_lock = threading.Lock() # Thread-safe singleton lock
def __new__(cls, *args, **kwargs):
# Double-checked locking pattern for thread-safe singleton
if cls._instance is None:
cls._instance = super().__new__(cls)
with cls._lock:
# Check again inside lock to prevent race condition
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, **kwargs):
if TranscriptionEngine._initialized:
return
# Thread-safe initialization check
with TranscriptionEngine._lock:
if TranscriptionEngine._initialized:
return
# Set flag immediately to prevent re-initialization
TranscriptionEngine._initialized = True
# Perform initialization outside lock to avoid holding lock during slow operations
global_params = {
"host": "localhost",
"port": 8000,
@@ -36,7 +47,6 @@ class TranscriptionEngine:
"punctuation_split": False,
"target_language": "",
"vac": True,
"vac_onnx": False,
"vac_chunk_size": 0.04,
"log_level": "DEBUG",
"ssl_certfile": None,
@@ -59,6 +69,7 @@ class TranscriptionEngine:
"model_cache_dir": None,
"model_dir": None,
"model_path": None,
"lora_path": None,
"lan": "auto",
"direct_english_translation": False,
}
@@ -78,15 +89,19 @@ class TranscriptionEngine:
self.asr = None
self.tokenizer = None
self.diarization = None
self.vac_model = None
self.vac_session = None
if self.args.vac:
from whisperlivekit.silero_vad_iterator import load_silero_vad
# Use ONNX if specified, otherwise use JIT (default)
use_onnx = kwargs.get('vac_onnx', False)
self.vac_model = load_silero_vad(onnx=use_onnx)
from whisperlivekit.silero_vad_iterator import is_onnx_available
if is_onnx_available():
from whisperlivekit.silero_vad_iterator import load_onnx_session
self.vac_session = load_onnx_session()
else:
logger.warning(
"onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
"For multi-user scenarios, install onnxruntime: pip install onnxruntime"
)
backend_policy = self.args.backend_policy
if self.args.transcription:
if backend_policy == "simulstreaming":
@@ -168,16 +183,13 @@ class TranscriptionEngine:
}
translation_params = update_with_kwargs(translation_params, kwargs)
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True
def online_factory(args, asr):
if args.backend_policy == "simulstreaming":
if args.backend_policy == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(asr)
else:
online = OnlineASRProcessor(asr)
return online
return SimulStreamingOnlineProcessor(asr)
return OnlineASRProcessor(asr)
def online_diarization_factory(args, diarization_backend):

View File

@@ -202,14 +202,14 @@ class DiartDiarization:
def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray):
"""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
def insert_audio_chunk(self, pcm_array: np.ndarray):
"""Buffer audio for the next diarization step."""
if self.custom_source:
self.custom_source.push_audio(pcm_array)
# self.observer.clear_old_segments()
self.custom_source.push_audio(pcm_array)
async def diarize(self):
"""Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments()
def close(self):
"""Close the audio source."""

View File

@@ -7,7 +7,7 @@ from typing import List
import numpy as np
import soundfile as sf
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
@@ -16,9 +16,10 @@ class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
self.logfile = logfile
self.transcribe_kargs = {}
self.lora_path = lora_path
if lan == "auto":
self.original_language = None
else:
@@ -47,24 +48,23 @@ class WhisperASR(ASRBase):
sep = " "
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from whisperlivekit.whisper import load_model as load_model
from whisperlivekit.whisper import load_model as load_whisper_model
if model_dir is not None:
resolved_path = resolve_model_path(model_dir)
resolved_path = resolve_model_path(model_dir)
if resolved_path.is_dir():
pytorch_path, _, _ = model_path_and_type(resolved_path)
if pytorch_path is None:
model_info = detect_model_format(resolved_path)
if not model_info.has_pytorch:
raise FileNotFoundError(
f"No supported PyTorch checkpoint found under {resolved_path}"
)
resolved_path = pytorch_path
)
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
return load_model(str(resolved_path))
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
if model_size is None:
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
return load_model(model_size, download_root=cache_dir)
return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)
def transcribe(self, audio, init_prompt=""):
options = dict(self.transcribe_kargs)
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
self.load_model()
self.use_vad_opt = False
self.direct_english_translation = False
self.task = "transcribe"
def load_model(self, *args, **kwargs):
from openai import OpenAI
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
params["language"] = self.original_language
if prompt:
params["prompt"] = prompt
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
task = self.transcribe_kargs.get("task", self.task)
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
transcript = proc.create(**params)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
return transcript

View File

@@ -10,7 +10,7 @@ import numpy as np
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.warmup import warmup_asr
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
@@ -77,6 +77,7 @@ def backend_factory(
model_cache_dir,
model_dir,
model_path,
lora_path,
direct_english_translation,
buffer_trimming,
buffer_trimming_sec,
@@ -87,16 +88,20 @@ def backend_factory(
backend_choice = backend
custom_reference = model_path or model_dir
resolved_root = None
pytorch_checkpoint = None
has_mlx_weights = False
has_fw_weights = False
has_pytorch = False
if custom_reference:
resolved_root = resolve_model_path(custom_reference)
if resolved_root.is_dir():
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
model_info = detect_model_format(resolved_root)
has_mlx_weights = model_info.compatible_whisper_mlx
has_fw_weights = model_info.compatible_faster_whisper
has_pytorch = model_info.has_pytorch
else:
pytorch_checkpoint = resolved_root
# Single file provided
has_pytorch = True
if backend_choice == "openai-api":
logger.debug("Using OpenAI API.")
@@ -121,8 +126,8 @@ def backend_factory(
model_override = str(resolved_root) if resolved_root is not None else None
else:
asr_cls = WhisperASR
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
if custom_reference and model_override is None:
model_override = str(resolved_root) if resolved_root is not None else None
if custom_reference and not has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
)
@@ -134,12 +139,14 @@ def backend_factory(
lan=lan,
cache_dir=model_cache_dir,
model_dir=model_override,
lora_path=lora_path if backend_choice == "whisper" else None,
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")
if direct_english_translation:
tgt_language = "en" # Whisper translates into English
asr.transcribe_kargs["task"] = "translate"
else:
tgt_language = lan # Whisper transcribes in this language
@@ -148,9 +155,9 @@ def backend_factory(
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming

View File

@@ -1,49 +1,195 @@
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
@dataclass
class ModelInfo:
"""Information about detected model format and files in a directory."""
path: Optional[Path] = None
pytorch_files: List[Path] = field(default_factory=list)
compatible_whisper_mlx: bool = False
compatible_faster_whisper: bool = False
@property
def has_pytorch(self) -> bool:
return len(self.pytorch_files) > 0
@property
def is_sharded(self) -> bool:
return len(self.pytorch_files) > 1
@property
def primary_pytorch_file(self) -> Optional[Path]:
"""Return the primary PyTorch file (or first shard for sharded models)."""
if not self.pytorch_files:
return None
return self.pytorch_files[0]
#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin
SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$")
FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"}
MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"}
CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"}
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
"""
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
CTranslate2 models have specific companion files that distinguish them
from PyTorch .bin files.
"""
n_indicators = 0
for indicator in CT2_INDICATOR_FILES: #test 1
if (directory / indicator).exists():
n_indicators += 1
if n_indicators == 0:
return False
config_path = directory / "config.json" #test 2
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
if config.get("model_type") == "whisper": #test 2
return False
except (json.JSONDecodeError, IOError):
pass
return True
def _collect_pytorch_files(directory: Path) -> List[Path]:
"""
Collect all PyTorch checkpoint files from a directory.
Handles:
- Single files: model.safetensors, pytorch_model.bin, *.pt
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
- Index-based sharded models (reads index file to find shards)
Returns files sorted appropriately (shards in order, or single file).
"""
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
index_path = directory / index_name
if index_path.exists():
try:
with open(index_path, "r", encoding="utf-8") as f:
index_data = json.load(f)
weight_map = index_data.get("weight_map", {})
if weight_map:
shard_names = sorted(set(weight_map.values()))
shards = [directory / name for name in shard_names if (directory / name).exists()]
if shards:
return shards
except (json.JSONDecodeError, IOError):
pass
sharded_groups = {}
single_files = {}
for file in directory.iterdir():
if not file.is_file():
continue
filename = file.name
suffix = file.suffix.lower()
if filename.startswith("adapter_"):
continue
match = SHARDED_PATTERN.match(filename)
if match:
base_name, shard_idx, total_shards, ext = match.groups()
key = (base_name, ext, int(total_shards))
if key not in sharded_groups:
sharded_groups[key] = []
sharded_groups[key].append((int(shard_idx), file))
continue
if filename == "model.safetensors":
single_files[0] = file # Highest priority
elif filename == "pytorch_model.bin":
single_files[1] = file
elif suffix == ".pt":
single_files[2] = file
elif suffix == ".safetensors" and not filename.startswith("adapter"):
single_files[3] = file
for (base_name, ext, total_shards), shards in sharded_groups.items():
if len(shards) == total_shards:
return [path for _, path in sorted(shards)]
for priority in sorted(single_files.keys()):
return [single_files[priority]]
return []
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
"""
Detect the model format in a given path.
This function analyzes a file or directory to determine:
- What PyTorch checkpoint files are available (including sharded models)
- Whether the directory contains MLX Whisper weights
- Whether the directory contains Faster-Whisper (CTranslate2) weights
Args:
model_path: Path to a model file or directory
Returns:
ModelInfo with detected format information
"""
path = Path(model_path)
info = ModelInfo(path=path)
if path.is_file():
suffix = path.suffix.lower()
if suffix in {".pt", ".safetensors", ".bin"}:
info.pytorch_files = [path]
return info
if not path.is_dir():
return info
for file in path.iterdir():
if not file.is_file():
continue
filename = file.name.lower()
if filename in MLX_WHISPER_MARKERS:
info.compatible_whisper_mlx = True
if filename in FASTER_WHISPER_MARKERS:
if _is_ct2_model_bin(path, filename):
info.compatible_faster_whisper = True
info.pytorch_files = _collect_pytorch_files(path)
return info
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
"""
Inspect the provided path and determine which model formats are available.
This is a compatibility wrapper around detect_model_format().
Returns:
pytorch_path: Path to a PyTorch checkpoint (if present).
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
compatible_whisper_mlx: True if MLX weights exist in this folder.
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
"""
path = Path(model_path)
compatible_whisper_mlx = False
compatible_faster_whisper = False
pytorch_path: Optional[Path] = None
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
pytorch_path = path
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
if path.is_dir():
for file in path.iterdir():
if not file.is_file():
continue
filename = file.name.lower()
suffix = file.suffix.lower()
if filename in {"weights.npz", "weights.safetensors"}:
compatible_whisper_mlx = True
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
compatible_faster_whisper = True
elif suffix in {".pt", ".safetensors"}:
pytorch_path = file
elif filename == "pytorch_model.bin":
pytorch_path = file
if pytorch_path is None:
fallback = path / "pytorch_model.bin"
if fallback.exists():
pytorch_path = fallback
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
info = detect_model_format(model_path)
return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper
def resolve_model_path(model_path: Union[str, Path]) -> Path:
@@ -59,7 +205,7 @@ def resolve_model_path(model_path: Union[str, Path]) -> Path:
try:
from huggingface_hub import snapshot_download
except ImportError as exc: # pragma: no cover - optional dependency guard
except ImportError as exc:
raise FileNotFoundError(
f"Model path '{model_path}' does not exist locally and huggingface_hub "
"is not installed to download it."

View File

@@ -106,6 +106,13 @@ def parse_args():
default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
)
parser.add_argument(
"--lora-path",
type=str,
default=None,
dest="lora_path",
help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). Only works with native Whisper backend.",
)
parser.add_argument(
"--lan",
"--language",

View File

@@ -8,6 +8,15 @@ import torch
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
"""
def is_onnx_available() -> bool:
"""Check if onnxruntime is installed."""
try:
import onnxruntime
return True
except ImportError:
return False
def init_jit_model(model_path: str, device=torch.device('cpu')):
"""Load a JIT model from file."""
model = torch.jit.load(model_path, map_location=device)
@@ -15,12 +24,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
return model
class OnnxWrapper():
"""ONNX Runtime wrapper for Silero VAD model."""
class OnnxSession():
"""
Shared ONNX session for Silero VAD model (stateless).
"""
def __init__(self, path, force_onnx_cpu=False):
global np
import numpy as np
import onnxruntime
opts = onnxruntime.SessionOptions()
@@ -32,13 +41,28 @@ class OnnxWrapper():
else:
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states()
self.path = path
if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000]
class OnnxWrapper():
"""
ONNX Runtime wrapper for Silero VAD model with per-instance state.
"""
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
self._shared_session = session
self.sample_rates = session.sample_rates
self.reset_states()
@property
def session(self):
return self._shared_session.session
def _validate_input(self, x, sr: int):
if x.dim() == 1:
x = x.unsqueeze(0)
@@ -101,38 +125,20 @@ class OnnxWrapper():
return out
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
"""
Load Silero VAD model (JIT or ONNX).
Parameters
----------
model_path : str, optional
Path to model file. If None, uses default bundled model.
onnx : bool, default False
Whether to use ONNX runtime (requires onnxruntime package).
opset_version : int, default 16
ONNX opset version (15 or 16). Only used if onnx=True.
Returns
-------
model
Loaded VAD model (JIT or ONNX wrapper)
"""
def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
"""Get the path to the ONNX model file."""
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
if opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models'
if onnx:
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = 'silero_vad.jit'
model_name = f'silero_vad_16k_op{opset_version}.onnx'
model_path = data_dir / model_name
@@ -143,17 +149,39 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
)
else:
model_path = Path(model_path)
if onnx:
try:
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
except ImportError:
raise ImportError(
"ONNX runtime not available. Install with: pip install onnxruntime\n"
"Or use JIT model by setting onnx=False"
return model_path
def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
"""
Load a shared ONNX session for Silero VAD.
"""
path = _get_onnx_model_path(model_path, opset_version)
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
def load_jit_vad(model_path: str = None):
"""
Load Silero VAD model in JIT format.
"""
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models'
model_name = 'silero_vad.jit'
model_path = data_dir / model_name
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
)
else:
model = init_jit_model(str(model_path))
model_path = Path(model_path)
model = init_jit_model(str(model_path))
return model
@@ -285,13 +313,14 @@ class FixedVADIterator(VADIterator):
if __name__ == "__main__":
model = load_silero_vad(onnx=False)
vad = FixedVADIterator(model)
# vad = FixedVADIterator(load_jit_vad())
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer)
print(f" 512 samples: {result}")
# test with 511 samples
audio_buffer = np.array([0] * 511, dtype=np.float32)
result = vad(audio_buffer)
result = vad(audio_buffer)
print(f" 511 samples: {result}")

View File

@@ -11,7 +11,7 @@ import torch
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
@@ -24,9 +24,11 @@ logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER:
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
from .mlx import MLXAlignAtt
else:
mlx_model_mapping = {}
MLXAlignAtt = None
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
if HAS_FASTER_WHISPER:
from faster_whisper import WhisperModel
@@ -36,50 +38,47 @@ else:
MIN_DURATION_REAL_SILENCE = 5
class SimulStreamingOnlineProcessor:
"""Online processor for SimulStreaming ASR."""
SAMPLING_RATE = 16000
def __init__(
self,
asr,
logfile=sys.stderr,
):
def __init__(self, asr, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_alignatt_instance()
self.model = self._create_alignatt()
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
self.model.state.tokenizer = asr.tokenizer
def load_new_alignatt_instance(self):
"""Initialize AlignAtt decoder using the shared model."""
self.model = AlignAtt(
cfg=self.asr.cfg,
loaded_model=self.asr.shared_model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def _create_alignatt(self):
"""Create the AlignAtt decoder instance based on ASR mode."""
if self.asr.use_full_mlx and HAS_MLX_WHISPER:
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
else:
return AlignAtt(
cfg=self.asr.cfg,
loaded_model=self.asr.shared_model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def start_silence(self):
tokens, processed_upto = self.process_iter(is_last=True)
return tokens, processed_upto
def end_silence(self, silence_duration, offset):
"""
Handle silence period.
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
Otherwise, insert a small silence and shift the last_attend_frame.
"""
"""Handle silence period."""
self.end += silence_duration
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
if not long_silence:
gap_len = int(16000 * silence_duration)
if gap_len > 0:
gap_silence = torch.zeros(gap_len)
if self.asr.use_full_mlx:
gap_silence = np.zeros(gap_len, dtype=np.float32)
else:
gap_silence = torch.zeros(gap_len)
self.model.insert_audio(gap_silence)
if long_silence:
self.model.refresh_segment(complete=True)
@@ -87,11 +86,12 @@ class SimulStreamingOnlineProcessor:
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
self.model.insert_audio(audio_tensor)
self.end = audio_stream_end_time
if self.asr.use_full_mlx:
self.model.insert_audio(audio)
else:
audio_tensor = torch.from_numpy(audio).float()
self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker):
"""Handle speaker change event."""
@@ -120,7 +120,6 @@ class SimulStreamingOnlineProcessor:
self.buffer.extend(timestamped_words)
return [], self.end
self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end
except Exception as e:
@@ -130,6 +129,10 @@ class SimulStreamingOnlineProcessor:
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
if self.asr.use_full_mlx:
# MLX mode: ensure numpy array
if hasattr(audio, 'numpy'):
audio = audio.numpy()
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
@@ -139,9 +142,14 @@ class SimulStreamingOnlineProcessor:
def __del__(self):
gc.collect()
torch.cuda.empty_cache()
if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
try:
torch.cuda.empty_cache()
except Exception:
pass
class SimulStreamingASR():
class SimulStreamingASR:
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
@@ -158,35 +166,25 @@ class SimulStreamingASR():
self.fast_encoder = False
self._resolved_model_path = None
self.encoder_backend = "whisper"
self.use_full_mlx = getattr(self, "use_full_mlx", False)
preferred_backend = getattr(self, "backend", "auto")
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
compatible_whisper_mlx, compatible_faster_whisper = True, True
if self.model_path:
resolved_model_path = resolve_model_path(self.model_path)
self._resolved_model_path = resolved_model_path
self.model_path = str(resolved_model_path)
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
if self.pytorch_path:
self.model_name = self.pytorch_path.stem
else:
self.model_name = Path(self.model_path).stem
model_info = detect_model_format(resolved_model_path)
compatible_whisper_mlx = model_info.compatible_whisper_mlx
compatible_faster_whisper = model_info.compatible_faster_whisper
if not self.use_full_mlx and not model_info.has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
)
)
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
elif self.model_size is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_name = self.model_size
else:
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
@@ -201,6 +199,10 @@ class SimulStreamingASR():
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
if self.encoder_backend == "whisper":
self.disable_fast_encoder = True
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
if not hasattr(self, '_full_mlx_disabled'):
self.use_full_mlx = True
self.cfg = AlignAttConfig(
tokenizer_is_multilingual= is_multilingual,
@@ -212,7 +214,7 @@ class SimulStreamingASR():
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.direct_english_translation,
task="translate" if self.direct_english_translation else "transcribe",
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
@@ -225,20 +227,36 @@ class SimulStreamingASR():
else:
self.tokenizer = None
self.mlx_encoder, self.fw_encoder = None, None
if self.encoder_backend == "mlx-whisper":
print('Simulstreaming will use MLX whisper to increase encoding speed.')
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
self.shared_model = None
if self.use_full_mlx and HAS_MLX_WHISPER:
logger.info('MLX Whisper backend used.')
if self._resolved_model_path is not None:
mlx_model = str(self._resolved_model_path)
mlx_model_path = str(self._resolved_model_path)
else:
mlx_model = mlx_model_mapping.get(self.model_name)
if not mlx_model:
mlx_model_path = mlx_model_mapping.get(self.model_name)
if not mlx_model_path:
raise FileNotFoundError(
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
)
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
self._warmup_mlx_model()
elif self.encoder_backend == "mlx-whisper":
# hybrid mode: mlx encoder + pytorch decoder
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
if self._resolved_model_path is not None:
mlx_model_path = str(self._resolved_model_path)
else:
mlx_model_path = mlx_model_mapping.get(self.model_name)
if not mlx_model_path:
raise FileNotFoundError(
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
)
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
self.shared_model = self.load_model()
elif self.encoder_backend == "faster-whisper":
print('Simulstreaming will use Faster Whisper for the encoder.')
print('SimulStreaming will use Faster Whisper for the encoder.')
if self._resolved_model_path is not None:
fw_model = str(self._resolved_model_path)
else:
@@ -248,7 +266,20 @@ class SimulStreamingASR():
device='auto',
compute_type='auto',
)
self.shared_model = self.load_model()
self.shared_model = self.load_model()
else:
self.shared_model = self.load_model()
def _warmup_mlx_model(self):
"""Warmup the full MLX model."""
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
temp_model = MLXAlignAtt(
cfg=self.cfg,
mlx_model=self.mlx_model,
)
temp_model.warmup(warmup_audio)
logger.info("Full MLX model warmed up successfully")
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
@@ -292,11 +323,14 @@ class SimulStreamingASR():
return True
def load_model(self):
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model(
name=self.pytorch_path if self.pytorch_path else self.model_name,
download_root=self.model_path,
name=model_ref,
download_root=getattr(self, 'model_cache_dir', None),
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads
custom_alignment_heads=self.custom_alignment_heads,
lora_path=lora_path,
)
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:

View File

@@ -47,9 +47,24 @@ class DecoderState:
def clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.kv_cache = {}
# Explicitly delete tensor references to free GPU memory
if self.kv_cache:
for key in list(self.kv_cache.keys()):
tensor = self.kv_cache.pop(key, None)
if tensor is not None:
del tensor
# Clear the dict
self.kv_cache.clear()
# Force GPU cache cleanup (only if CUDA is available)
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = self.kv_cache
# Create NEW dict instead of sharing reference
self.inference.kv_cache = {}
if self.token_decoder is not None:
self.token_decoder.reset()

View File

@@ -0,0 +1,11 @@
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
from .simul_whisper import MLXAlignAtt
__all__ = [
"MLXAlignAtt",
"MLXBeamSearchDecoder",
"MLXDecoderState",
"MLXGreedyDecoder",
"MLXInference",
]

View File

@@ -0,0 +1,76 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
@dataclass
class MLXDecoderState:
"""
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
where each element is a tuple of mx.arrays.
"""
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[mx.array] = field(default_factory=list)
initial_tokens: Optional[mx.array] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[np.ndarray] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
cif_weights: Optional[mx.array] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens: Optional[Tuple[int, ...]] = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
self.kv_cache = None
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = None
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = None
self.first_timestamp = None

View File

@@ -0,0 +1,219 @@
"""
MLX-native token decoders for streaming ASR.
"""
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
class MLXGreedyDecoder:
"""Greedy decoder using MLX operations."""
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens with next predicted token.
Args:
tokens: Current token sequence, shape (batch, seq_len)
logits: Logits for next token, shape (batch, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch,)
Returns:
Updated tokens and completion flag
"""
if self.temperature == 0:
next_tokens = mx.argmax(logits, axis=-1)
else:
probs = mx.softmax(logits / self.temperature, axis=-1)
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
batch_size = logprobs.shape[0]
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
sum_logprobs = sum_logprobs + current_logprobs * mask
eot_mask = (tokens[:, -1] == self.eot)
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
completed = bool(mx.all(tokens[:, -1] == self.eot))
return tokens, completed
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
"""Finalize decoding by ensuring EOT at end."""
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
tokens = mx.concatenate([tokens, eot_column], axis=1)
return tokens, sum_logprobs.tolist()
class MLXBeamSearchDecoder:
"""Beam search decoder using MLX operations."""
def __init__(
self,
beam_size: int,
eot: int,
inference: Any,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences: Optional[List[Dict]] = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
"""Reset finished sequences for new segment."""
self.finished_sequences = None
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens using beam search.
Args:
tokens: Current token sequences, shape (batch * beam_size, seq_len)
logits: Logits for next token, shape (batch * beam_size, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
Returns:
Updated tokens and completion flag
"""
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
logprobs_np = np.array(logprobs)
tokens_np = np.array(tokens)
sum_logprobs_np = np.array(sum_logprobs)
next_tokens, source_indices, finished_sequences = [], [], []
new_sum_logprobs = []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens_np[idx].tolist()
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
for token_idx in top_k_indices:
logprob = logprobs_np[idx, token_idx]
new_logprob = sum_logprobs_np[idx] + logprob
sequence = tuple(prefix + [int(token_idx)])
scores[sequence] = new_logprob
sources[sequence] = idx
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
new_sum_logprobs.append(scores[sequence])
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
self.inference.rearrange_kv_cache(source_indices)
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break
previously_finished[seq] = newly_finished[seq]
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
"""Finalize beam search by selecting best sequences."""
preceding_tokens_np = np.array(preceding_tokens)
sum_logprobs_np = np.array(sum_logprobs)
n_audio = preceding_tokens_np.shape[0] // self.beam_size
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
sum_logprobs_list: List[float] = [0.0] * n_audio
for i, sequences in enumerate(self.finished_sequences):
if sequences:
best_seq = max(sequences, key=sequences.get)
tokens_list[i] = list(best_seq)
sum_logprobs_list[i] = sequences[best_seq]
else:
idx = i * self.beam_size
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
max_len = max(len(t) for t in tokens_list)
for i, t in enumerate(tokens_list):
tokens_list[i] = t + [self.eot] * (max_len - len(t))
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
return tokens, sum_logprobs_list
class MLXInference:
"""MLX inference wrapper for beam search KV cache management."""
def __init__(self, model, initial_token_length: int):
self.model = model
self.initial_token_length = initial_token_length
self.kv_cache = None
def rearrange_kv_cache(self, source_indices: List[int]):
"""Rearrange KV cache based on beam search source indices."""
if self.kv_cache is None:
return
if source_indices == list(range(len(source_indices))):
return
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
new_cache = []
for layer_cache in self.kv_cache:
(k, v), (cross_k, cross_v) = layer_cache
new_k = k[source_indices_mx]
new_v = v[source_indices_mx]
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
self.kv_cache = new_cache
def logits(
self,
tokens: mx.array,
audio_features: mx.array,
) -> Tuple[mx.array, List]:
"""Get logits from decoder with KV cache."""
logits, self.kv_cache, cross_qk = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits, cross_qk

View File

@@ -0,0 +1,756 @@
"""
MLX whisper AlignAtt streaming decoder
"""
import logging
from time import time
from typing import Any, List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
DEC_PAD = 50257
logger = logging.getLogger(__name__)
class MLXTokenBuffer: #should try to make it heritate from classic simul whisper class
"""Token buffer for MLX-based decoding."""
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
self.text = text
self.prefix_token_ids = prefix_token_ids or []
self.tokenizer = tokenizer
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_mlx_array(self) -> mx.array:
"""Return tokens as MLX array."""
tok_ids = self.as_token_ids()
return mx.array([tok_ids], dtype=mx.int32)
def as_mlx_array_beam(self, beam: int) -> mx.array:
"""Return tokens as MLX array repeated for beam search."""
t = self.as_mlx_array()
return mx.repeat(t, beam, axis=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return MLXTokenBuffer(*a, **kw)
@staticmethod
def from_text(text, *a, **kw):
return MLXTokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
"""Trim words from the beginning of the context."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
"""Append token IDs to the buffer, handling incomplete UTF-8."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
"""
Apply median filter along the last axis.
Args:
x: Input array of shape (..., T)
filter_width: Width of the median filter (should be odd)
Returns:
Filtered array of same shape
"""
if filter_width <= 1:
return x
pad_width = filter_width // 2
shape = x.shape
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
result_shape = list(shape)
result = []
for i in range(shape[-1]):
window = x_padded[..., i:i + filter_width]
sorted_window = mx.sort(window, axis=-1)
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
result.append(median_val)
return mx.concatenate(result, axis=-1)
class MLXAlignAtt:
"""
MLX-native Alignment-based Attention decoder for SimulStreaming.
This class runs entirely on MLX, with no PyTorch dependencies for inference.
"""
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__(
self,
cfg: AlignAttConfig,
mlx_model: Any,
) -> None:
"""
Initialize MLX AlignAtt decoder.
Args:
cfg: AlignAtt configuration
mlx_model: MLX Whisper model (full model, not just encoder)
"""
self.model = mlx_model
self.cfg = cfg
logger.info(f"MLX Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
language=cfg.language,
without_timestamps=True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = MLXDecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
self.state.never_fire = True
self.state.always_fire = False
else:
self.state.always_fire = True
self.state.never_fire = False
else:
logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True")
self.state.always_fire = True
self.state.never_fire = cfg.never_fire
self._build_alignment_source()
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
self.init_tokens()
self.init_context()
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using MLX greedy decoder")
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
logger.info("Using MLX beam decoder")
self.state.inference = MLXInference(self.model, self.state.initial_token_length)
self.state.token_decoder = MLXBeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size
)
def _build_alignment_source(self):
"""Build alignment source mapping from model's alignment_heads."""
self.state.align_source = {}
self.state.num_align_heads = 0
alignment_heads = self.model.alignment_heads
if alignment_heads is None:
logger.warning("No alignment heads found in model")
return
if hasattr(alignment_heads, 'tolist'):
heads_list = alignment_heads.tolist()
else:
heads_list = np.array(alignment_heads).tolist()
for layer_rank, head_id in heads_list:
layer_rank = int(layer_rank)
head_id = int(head_id)
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
def warmup(self, audio: np.ndarray):
"""Warmup the model with sample audio."""
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("MLX model warmed up successfully")
except Exception as e:
logger.exception(f"MLX model warmup failed: {e}")
def create_tokenizer(self, language=None):
"""Create tokenizer for the given language."""
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
self.state.tokenizer = self.tokenizer
def init_context(self):
"""Initialize context buffer."""
kw = {
'tokenizer': self.tokenizer,
'prefix_token_ids': [self.tokenizer.sot_prev]
}
self.state.context = MLXTokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.state.context.text += self.cfg.init_prompt
def init_tokens(self):
"""Initialize token sequence."""
logger.debug(f"init tokens, {len(self.state.segments)}")
self.state.initial_tokens = mx.array(
[self.tokenizer.sot_sequence_including_notimestamps],
dtype=mx.int32
)
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def trim_context(self):
"""Trim context if too long."""
logger.info("Trimming context")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
def refresh_segment(self, complete=False):
"""Refresh segment state."""
logger.debug("Refreshing segment:")
self.init_tokens()
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
"""Check if we should fire at word boundary (CIF-based)."""
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return True
def _current_tokens(self) -> mx.array:
"""Get current token sequence for decoding."""
toks = self.state.tokens
if toks[0].shape[0] == 1:
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
if not self.state.context.is_empty():
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
toks = [context_toks] + toks
# Concatenate all tokens
if len(toks) > 1:
current_tokens = mx.concatenate(toks, axis=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def debug_print_tokens(self, tokens: mx.array):
"""Debug print token sequences."""
tokens_np = np.array(tokens)
for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
def segments_len(self) -> float:
"""Get total length of audio segments in seconds."""
return sum(s.shape[0] for s in self.state.segments) / 16000
def _apply_minseglen(self) -> bool:
"""Check if we have enough audio to process."""
segments_len = self.segments_len()
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def insert_audio(self, segment: np.ndarray = None):
"""Insert audio segment into buffer."""
if segment is not None:
if hasattr(segment, 'numpy'):
segment = segment.numpy()
self.state.segments.append(segment)
removed_len = 0
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len
self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.state.tokens) > 1:
# Convert MLX array to list for context
token_list = np.array(self.state.tokens[1][0, :]).tolist()
self.state.context.append_token_ids(token_list)
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.state.clean_cache()
def _suppress_tokens(self, logits: mx.array) -> mx.array:
"""Apply token suppression to logits."""
if self.state.suppress_tokens:
suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
logits = logits.at[:, suppress_indices].add(-float('inf'))
return logits
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
"""Language detection from encoder features."""
n_audio = encoder_features.shape[0]
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
logits = logits[:, 0]
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
mask = mask.at[language_token_indices].add(False)
logits = mx.where(mask, mx.array(-float('inf')), logits)
language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1)
probs_np = np.array(language_token_probs)
language_probs = [
{
c: float(probs_np[i, j])
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
self._clean_cache()
return language_tokens, language_probs
def infer(self, is_last: bool = False) -> List[ASRToken]:
"""
Main inference method.
Args:
is_last: Whether this is the final chunk
Returns:
List of timestamped ASR tokens
"""
new_segment = True
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
return []
if len(self.state.segments) > 1:
input_segments = np.concatenate(self.state.segments, axis=0)
else:
input_segments = self.state.segments[0]
beg_encode = time()
mlx_mel_padded = mlx_log_mel_spectrogram(
audio=input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES
)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
encoder_feature = self.model.encoder(mlx_mel[None])
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
mx.eval(encoder_feature)
end_encode = time()
logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}")
self.trim_context()
current_tokens = self._current_tokens()
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
completed = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
l_absolute_timestamps = []
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len:
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding]
break
if new_segment:
tokens_for_logits = current_tokens
else:
tokens_for_logits = current_tokens[:, -1:]
if self.state.decoder_type == "greedy":
logits, self.state.kv_cache, cross_qk = self.model.decoder(
tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
)
else:
logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
mx.eval(logits)
accumulated_cross_attns.append(cross_qk)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # Last token logits
# Suppress tokens at segment start
if new_segment:
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
logits = logits.at[:, blank_tokens].add(-float('inf'))
new_segment = False
logits = self._suppress_tokens(logits)
current_tokens, completed = self.state.token_decoder.update(
current_tokens, logits, sum_logprobs
)
mx.eval(current_tokens)
logger.debug(f"Decoding completed: {completed}")
self.debug_print_tokens(current_tokens)
attn_of_alignment_heads = self._process_cross_attention(
accumulated_cross_attns, content_mel_len
)
most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
most_attended_frames_np = np.array(most_attended_frames)
absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in most_attended_frames_np.tolist()
]
logger.debug(str(most_attended_frames_np.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
most_attended_frame = int(most_attended_frames_np[0])
l_absolute_timestamps.append(absolute_timestamps[0])
if completed:
current_tokens = current_tokens[:, :-1]
break
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
current_tokens_np = np.array(current_tokens)
if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
logger.debug("omit rewinding from special tokens")
self.state.last_attend_frame = most_attended_frame
else:
logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
self.state.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = mx.concatenate(self.state.tokens, axis=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
break
else:
self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
current_tokens = current_tokens[:, :-1]
break
tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
if self.state.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
if fire_detected or is_last:
new_hypothesis = tokens_to_split
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
self.state.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except IndexError:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text=word,
speaker=self.state.speaker,
detected_language=self.state.detected_language
).with_offset(self.state.global_time_offset)
timestamped_words.append(timestamp_entry)
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10
if split_words and replacement_char in split_words[-1]:
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug(f"[UTF-8 Fix] Holding incomplete tokens")
else:
logger.warning(f"[UTF-8 Fix] Skipping too many tokens")
return timestamped_words
def _process_cross_attention(
self,
cross_attns: List[List[mx.array]],
content_mel_len: int
) -> mx.array:
"""
Process cross-attention weights for alignment.
Args:
cross_attns: List of cross-attention from each forward pass
Each element is a list of mx.arrays per layer
content_mel_len: Length of actual audio content
Returns:
Processed attention tensor, shape (batch, seq_len, content_mel_len)
"""
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = self.num_decoder_layers
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
if attn_mat is None:
continue
layer_rank = idx % num_decoder_layers
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
attn_mat = mx.softmax(attn_mat, axis=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
if attn_mat.ndim == 4:
a = attn_mat[0, head_id, :, :]
else:
a = attn_mat[head_id, :, :]
a = a[None, :, :]
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
t = mx.concatenate(mat, axis=1)
tmp.append(t)
if not tmp:
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
attn_of_alignment_heads = mx.stack(tmp, axis=1)
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
mx.eval(attn_of_alignment_heads)
return attn_of_alignment_heads

View File

@@ -68,4 +68,40 @@ def load_mlx_encoder(
model.update(encoder_weights)
mx.eval(model.parameters())
return model
def load_mlx_model(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model

View File

@@ -390,7 +390,6 @@ class AlignAtt:
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.state.segments, dim=0)
return []
# input_segments is concatenation of audio, it's one array
@@ -484,7 +483,19 @@ class AlignAtt:
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding] # Discard all new tokens
break
if new_segment:
tokens_for_logits = current_tokens
@@ -496,8 +507,12 @@ class AlignAtt:
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result
# Accumulate cross-attention from this forward pass
# Accumulate cross-attention from this forward pass (rolling window to
# bound VRAM — only the last entry matters for alignment, and the
# median_filter kernel is 7, so 16 entries is more than enough).
accumulated_cross_attns.append(cross_attns)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
@@ -616,8 +631,10 @@ class AlignAtt:
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except:
pass
except IndexError:
# Use last timestamp if index out of range
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
@@ -631,11 +648,15 @@ class AlignAtt:
)
timestamped_words.append(timestamp_entry)
# Hold incomplete tokens for next chunk
# Hold incomplete tokens for next chunk (with limit to prevent hallucination accumulation)
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10 # Real incomplete UTF-8 chars are at most a few tokens
if split_words and replacement_char in split_words[-1]:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.warning(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.state.pending_incomplete_tokens}")
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk")
else:
logger.warning(f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens (exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)")
return timestamped_words
@@ -702,4 +723,4 @@ class AlignAtt:
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
return attn_of_alignment_heads
return attn_of_alignment_heads

View File

@@ -0,0 +1,139 @@
"""
Thread Safety Configuration for WhisperLiveKit
This module provides thread safety configuration and utilities.
Environment Variables:
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
Set to "0" to disable for single-connection deployments
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
Usage:
# Enable model locking (default)
export WHISPERLIVEKIT_MODEL_LOCK=1
# Disable for single-connection deployment
export WHISPERLIVEKIT_MODEL_LOCK=0
# Custom timeout
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading
logger = logging.getLogger(__name__)
# Configuration
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
# Global model lock
_model_lock = threading.Lock()
# Log configuration on import
if USE_MODEL_LOCK:
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
def get_model_lock():
"""Get the global model lock instance"""
return _model_lock
def acquire_model_lock(timeout=None):
"""
Acquire model lock with timeout.
Args:
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
Returns:
bool: True if lock acquired, False on timeout
"""
if not USE_MODEL_LOCK:
return True
timeout = timeout or LOCK_TIMEOUT
acquired = _model_lock.acquire(timeout=timeout)
if not acquired:
logger.error(f"Failed to acquire model lock within {timeout}s")
return acquired
def release_model_lock():
"""Release model lock"""
if not USE_MODEL_LOCK:
return
try:
_model_lock.release()
except RuntimeError:
# Lock not held - this is fine
pass
class ModelLockContext:
"""Context manager for model lock"""
def __init__(self, timeout=None):
self.timeout = timeout
self.acquired = False
def __enter__(self):
self.acquired = acquire_model_lock(self.timeout)
return self.acquired
def __exit__(self, exc_type, exc_val, exc_tb):
if self.acquired:
release_model_lock()
return False
# Concurrency recommendations
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
RECOMMENDED_WORKERS = 4
def print_deployment_recommendations():
"""Print recommended deployment configuration"""
print("\n" + "="*60)
print("WhisperLiveKit Deployment Recommendations")
print("="*60)
if USE_MODEL_LOCK:
print("⚠️ Model locking is ENABLED")
print(" This serializes inference across connections.")
print()
print("Recommended deployment:")
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
print(" -k uvicorn.workers.UvicornWorker \\")
print(" --worker-connections 1 \\")
print(" whisperlivekit.basic_server:app")
print()
print("Expected capacity:")
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
else:
print("✅ Model locking is DISABLED")
print(" ⚠️ ONLY safe for single-connection deployments")
print()
print("Recommended deployment:")
print(" uvicorn whisperlivekit.basic_server:app \\")
print(" --host 0.0.0.0 --port 8000 \\")
print(" --workers 1")
print()
print("Expected capacity:")
print(" - 1 concurrent user only")
print("="*60 + "\n")
if __name__ == "__main__":
print_deployment_recommendations()

View File

@@ -39,10 +39,11 @@ class TimedText(Timed):
@dataclass()
class ASRToken(TimedText):
probability: Optional[float] = None
def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
def is_silence(self) -> bool:
return False
@@ -114,6 +115,9 @@ class Segment(TimedText):
end: Optional[float]
text: Optional[str]
speaker: Optional[str]
tokens: Optional[ASRToken] = None
translation: Optional[Translation] = None
@classmethod
def from_tokens(
cls,
@@ -141,17 +145,13 @@ class Segment(TimedText):
speaker=-1,
detected_language=start_token.detected_language
)
def is_silence(self) -> bool:
"""True when this segment represents a silence gap."""
return self.speaker == -2
@dataclass
class Line(TimedText):
translation: str = ''
def to_dict(self) -> Dict[str, Any]:
"""Serialize the line for frontend consumption."""
"""Serialize the segment for frontend consumption."""
_dict: Dict[str, Any] = {
'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text,
@@ -163,29 +163,13 @@ class Line(TimedText):
if self.detected_language:
_dict['detected_language'] = self.detected_language
return _dict
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
"""Populate line attributes from a contiguous token list."""
self.text = ''.join([token.text for token in tokens])
self.start = tokens[0].start
self.end = tokens[-1].end
self.speaker = 1
self.detected_language = tokens[0].detected_language
return self
def build_from_segment(self, segment: Segment) -> "Line":
"""Populate the line fields from a pre-built segment."""
self.text = segment.text
self.start = segment.start
self.end = segment.end
self.speaker = segment.speaker
self.detected_language = segment.detected_language
return self
def is_silent(self) -> bool:
return self.speaker == -2
@dataclass
class PuncSegment(Segment):
pass
class SilentLine(Line):
class SilentSegment(Segment):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.speaker = -2
@@ -196,7 +180,7 @@ class SilentLine(Line):
class FrontData():
status: str = ''
error: str = ''
lines: list[Line] = field(default_factory=list)
lines: list[Segment] = field(default_factory=list)
buffer_transcription: str = ''
buffer_diarization: str = ''
buffer_translation: str = ''

View File

@@ -1,8 +1,8 @@
from time import time
from typing import Any, List, Optional, Tuple, Union
from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence,
SilentLine, SpeakerSegment,
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
SilentSegment, SpeakerSegment,
TimedText)
@@ -27,6 +27,14 @@ class TokensAlignment:
self.sep: str = sep if sep is not None else ' '
self.beg_loop: Optional[float] = None
self.validated_segments: List[Segment] = []
self.current_line_tokens: List[ASRToken] = []
self.diarization_buffer: List[ASRToken] = []
self.last_punctuation = None
self.last_uncompleted_punc_segment: PuncSegment = None
self.unvalidated_tokens: PuncSegment = []
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
@@ -39,27 +47,30 @@ class TokensAlignment:
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def add_translation(self, line: Line) -> None:
"""Append translated text segments that overlap with a line."""
def add_translation(self, segment: Segment) -> None:
"""Append translated text segments that overlap with a segment."""
if segment.translation is None:
segment.translation = ''
for ts in self.all_translation_segments:
if ts.is_within(line):
line.translation += ts.text + (self.sep if ts.text else '')
elif line.translation:
if ts.is_within(segment):
if ts.text:
segment.translation += ts.text + self.sep
elif segment.translation:
break
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
"""Group tokens into segments split by punctuation and explicit silence."""
segments = []
segment_start_idx = 0
for i, token in enumerate(self.all_tokens):
if token.is_silence():
previous_segment = Segment.from_tokens(
previous_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i],
)
if previous_segment:
segments.append(previous_segment)
segment = Segment.from_tokens(
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
@@ -67,19 +78,47 @@ class TokensAlignment:
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = Segment.from_tokens(
segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i+1],
)
segments.append(segment)
segment_start_idx = i+1
final_segment = Segment.from_tokens(
final_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx:],
)
if final_segment:
segments.append(final_segment)
return segments
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
new_punc_segments = []
segment_start_idx = 0
self.unvalidated_tokens += self.new_tokens
for i, token in enumerate(self.unvalidated_tokens):
if token.is_silence():
previous_segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i],
)
if previous_segment:
new_punc_segments.append(previous_segment)
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
new_punc_segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
)
new_punc_segments.append(segment)
segment_start_idx = i+1
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
return new_punc_segments
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
"""Merge consecutive diarization slices that share the same speaker."""
@@ -102,8 +141,8 @@ class TokensAlignment:
return max(0, end - start)
def get_lines_diarization(self) -> Tuple[List[Line], str]:
"""Build lines when diarization is enabled and track overflow buffer."""
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
"""Build segments when diarization is enabled and track overflow buffer."""
diarization_buffer = ''
punctuation_segments = self.compute_punctuations_segments()
diarization_segments = self.concatenate_diar_segments()
@@ -121,18 +160,18 @@ class TokensAlignment:
max_overlap_speaker = diarization_segment.speaker + 1
punctuation_segment.speaker = max_overlap_speaker
lines = []
segments = []
if punctuation_segments:
lines = [Line().build_from_segment(punctuation_segments[0])]
segments = [punctuation_segments[0]]
for segment in punctuation_segments[1:]:
if segment.speaker == lines[-1].speaker:
if lines[-1].text:
lines[-1].text += segment.text
lines[-1].end = segment.end
if segment.speaker == segments[-1].speaker:
if segments[-1].text:
segments[-1].text += segment.text
segments[-1].end = segment.end
else:
lines.append(Line().build_from_segment(segment))
segments.append(segment)
return lines, diarization_buffer
return segments, diarization_buffer
def get_lines(
@@ -140,40 +179,42 @@ class TokensAlignment:
diarization: bool = False,
translation: bool = False,
current_silence: Optional[Silence] = None
) -> Tuple[List[Line], str, Union[str, TimedText]]:
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
) -> Tuple[List[Segment], str, Union[str, TimedText]]:
"""Return the formatted segments plus buffers, optionally with diarization/translation."""
if diarization:
lines, diarization_buffer = self.get_lines_diarization()
segments, diarization_buffer = self.get_lines_diarization()
else:
diarization_buffer = ''
lines = []
current_line_tokens = []
for token in self.all_tokens:
if token.is_silence():
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
for token in self.new_tokens:
if isinstance(token, Silence):
if self.current_line_tokens:
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence
else:
lines.append(SilentLine(
start = token.start,
end = end_silence
self.validated_segments.append(SilentSegment(
start=token.start,
end=end_silence
))
else:
current_line_tokens.append(token)
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
self.current_line_tokens.append(token)
segments = list(self.validated_segments)
if self.current_line_tokens:
segments.append(Segment.from_tokens(self.current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
if segments and segments[-1].is_silence():
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
else:
lines.append(SilentLine(
start = current_silence.start,
end = end_silence
segments.append(SilentSegment(
start=current_silence.start,
end=end_silence
))
if translation:
[self.add_translation(line) for line in lines if not type(line) == Silence]
return lines, diarization_buffer, self.new_translation_buffer.text
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
return segments, diarization_buffer, self.new_translation_buffer.text

View File

@@ -108,7 +108,7 @@ def available_models() -> List[str]:
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
"""
attempt to infer ModelDimensions from a HF style config.json located
next to the given checkpoint, usefull for distilled models
next to the given checkpoint, usefull for distilled models/MLX models.
"""
candidates = []
if os.path.isdir(path):
@@ -122,6 +122,25 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
with open(candidate, "r", encoding="utf-8") as f:
config = json.load(f)
# native Whisper format
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
"n_text_head", "n_text_layer"]
if all(k in config for k in native_keys):
return ModelDimensions(
n_mels=config["n_mels"],
n_audio_ctx=config["n_audio_ctx"],
n_audio_state=config["n_audio_state"],
n_audio_head=config["n_audio_head"],
n_audio_layer=config["n_audio_layer"],
n_vocab=config["n_vocab"],
n_text_ctx=config["n_text_ctx"],
n_text_state=config["n_text_state"],
n_text_head=config["n_text_head"],
n_text_layer=config["n_text_layer"],
)
# HuggingFace format
try:
return ModelDimensions(
n_mels=config["num_mel_bins"],
@@ -236,6 +255,24 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
return converted if converted else state_dict
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Converts an mlx whisper checkpoint to a default openai whisper one
"""
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
return state_dict
converted = {}
for key, value in state_dict.items():
if key == "alignment_heads":
continue
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
converted[new_key] = value
return converted
def _load_lora_state(lora_path: str):
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin")
@@ -264,9 +301,49 @@ def _collapse_hf_module_name(module: str):
return module
def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
"""
Resolve LoRA adapter path - handles both local paths and HuggingFace repo IDs.
If lora_path is a local directory containing adapter files, returns it as-is.
If lora_path looks like a HuggingFace repo ID (contains '/'), downloads and caches it.
"""
if not lora_path:
return None
# Check if it's already a valid local path
if os.path.isdir(lora_path):
config_path = os.path.join(lora_path, "adapter_config.json")
if os.path.isfile(config_path):
return lora_path
# Try to download from HuggingFace Hub
if "/" in lora_path:
try:
from huggingface_hub import snapshot_download
local_path = snapshot_download(
repo_id=lora_path,
allow_patterns=["adapter_config.json", "adapter_model.*"],
)
return local_path
except Exception as e:
raise FileNotFoundError(
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
)
raise FileNotFoundError(
f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
)
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
if not lora_path:
return
# Resolve path (handles HuggingFace Hub download)
lora_path = _resolve_lora_path(lora_path)
if not lora_path:
return
config_path = os.path.join(lora_path, "adapter_config.json")
if not os.path.isfile(config_path):
@@ -319,6 +396,75 @@ def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str])
)
def _load_checkpoint(
file_path: Union[str, Path],
device: str,
in_memory: bool = False,
checkpoint_bytes: Optional[bytes] = None,
) -> Dict[str, torch.Tensor]:
"""
Load a checkpoint from a single file.
Handles .pt, .bin, and .safetensors formats.
"""
if checkpoint_bytes is not None:
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
file_path = Path(file_path)
suffix = file_path.suffix.lower()
if suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load .safetensors model files: `pip install safetensors`"
)
return load_file(str(file_path), device=device)
else:
if in_memory:
with open(file_path, "rb") as f:
checkpoint_bytes = f.read()
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
else:
with open(file_path, "rb") as fp:
return torch.load(fp, map_location=device)
def _load_sharded_checkpoint(
shard_files: List[Path],
device: str,
) -> Dict[str, torch.Tensor]:
"""
Load a sharded checkpoint (multiple .safetensors or .bin files).
Merges all shards into a single state dict.
"""
merged_state_dict = {}
first_suffix = shard_files[0].suffix.lower()
if first_suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load sharded .safetensors model: `pip install safetensors`"
)
for shard_path in shard_files:
shard_dict = load_file(str(shard_path), device=device)
merged_state_dict.update(shard_dict)
else:
for shard_path in shard_files:
with open(shard_path, "rb") as fp:
shard_dict = torch.load(fp, map_location=device)
if isinstance(shard_dict, dict):
merged_state_dict.update(shard_dict)
return merged_state_dict
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
@@ -336,6 +482,8 @@ def load_model(
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
Can be a single file (.pt, .bin, .safetensors), a directory containing model files,
or a sharded model directory with files like model-00001-of-00002.safetensors.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
@@ -350,16 +498,51 @@ def load_model(
model : Whisper
The Whisper ASR model instance
"""
from whisperlivekit.model_paths import detect_model_format
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
checkpoint = None
model_path_for_config = name # Used to find config.json for dims inference
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
if in_memory:
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_file)
else:
checkpoint = _load_checkpoint(checkpoint_file, device)
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
if in_memory:
with open(name, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(name, device)
model_path_for_config = name
elif os.path.isdir(name):
model_info = detect_model_format(name)
if not model_info.has_pytorch:
raise RuntimeError(
f"No PyTorch checkpoint found in directory {name}. "
f"Expected .pt, .bin, or .safetensors file(s)."
)
if model_info.is_sharded:
checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
else:
single_file = model_info.pytorch_files[0]
if in_memory:
with open(single_file, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(single_file, device)
model_path_for_config = name
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
@@ -369,34 +552,23 @@ def load_model(
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
if in_memory:
checkpoint = load_file(checkpoint_file, device=device)
else:
checkpoint = load_file(checkpoint_file, device=device)
else:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
if alignment_heads is None and "alignment_heads" in state_dict:
alignment_heads = state_dict["alignment_heads"]
state_dict = _convert_hf_state_dict(state_dict)
state_dict = _convert_mlx_state_dict(state_dict)
_apply_lora_adapter(state_dict, lora_path)
if dims_cfg is not None:
dims = ModelDimensions(**dims_cfg)
else:
dims = _infer_dims_from_config(name)
dims = _infer_dims_from_config(model_path_for_config)
if dims is None:
raise RuntimeError(
"Could not determine model dimensions. "
@@ -416,8 +588,13 @@ def load_model(
model.load_state_dict(state_dict)
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
if isinstance(alignment_heads, bytes):
model.set_alignment_heads(alignment_heads)
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
for layer, head in alignment_heads.tolist():
mask[layer, head] = True
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
return model.to(device)