21 Commits

Author SHA1 Message Date
Quentin Fuxa
d9a4c8dcb2 Refactor transcription and diarization handling with token-by-token validation. Introduce segment buffers for ephemeral content and update API to return structured segment data. Enhance silence handling and improve web interface for text transcripts. 2025-11-30 16:39:27 +01:00
Quentin Fuxa
4fb735a784 new token treatment only iar 2025-11-30 15:16:36 +01:00
Quentin Fuxa
d2f998cb7e val 2025-11-30 14:37:37 +01:00
Quentin Fuxa
7b18917f2b LoRA archi 2025-11-30 12:30:18 +01:00
Quentin Fuxa
f1113e3eb0 update with LoRA 2025-11-29 18:33:30 +01:00
Quentin Fuxa
cc5f819ce7 hf weights 2025-11-29 17:50:46 +01:00
Quentin Fuxa
82cd24bb75 LoRa path v0 - functional 2025-11-29 17:21:10 +01:00
Quentin Fuxa
d45c397c6a simulstreaming: limit n tokens to prevent hallucinations 2025-11-28 21:41:19 +01:00
Quentin Fuxa
45bf3f57d7 troubleshooting doc for aarch64 systems 2025-11-28 21:40:43 +01:00
Quentin Fuxa
1d88ba9d69 Fixes #294. improve model path backend detection and file extraction 2025-11-27 23:14:00 +01:00
Quentin Fuxa
c0965c6c31 Lines to Segments. Merging dataclasses 2025-11-27 21:54:58 +01:00
Quentin Fuxa
34ddd2ac02 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
345d781e97 update doc 2025-11-25 23:20:00 +01:00
Quentin Fuxa
28cf831701 indicate for context token limits for --max-context-tokens. bump to 0.2.16.dev0 2025-11-25 23:45:15 +01:00
Quentin Fuxa
60c62f8f84 troubleshooting #271 #276 #284 #286 2025-11-25 23:31:46 +01:00
Quentin Fuxa
7faa21f95f alignatt: enable model sharing by removing hooks and centralizing session state. Solves #282
Co-authored-by: Emmanuel Schmidbauer <eschmidbauer@gmail.com>
2025-11-25 23:07:42 +01:00
Quentin Fuxa
4e9f951551 correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
870141298c isort 2025-11-23 11:20:00 +01:00
Quentin Fuxa
872faa422a correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
fc9cb66813 disabling vac is not advised 2025-11-23 11:20:00 +01:00
Quentin Fuxa
a175d1a327 fixes silence detected but never reported by silero 2025-11-23 11:20:00 +01:00
53 changed files with 3076 additions and 1693 deletions

View File

@@ -1,24 +1,26 @@
<h1 align="center">WhisperLiveKit</h1>
<h1 align="center">WLK</h1>
<p align="center"><b>WhisperLiveKit: Ultra-low-latency, self-hosted speech-to-text with speaker identification</b></p>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://huggingface.co/qfuxa/whisper-base-french-lora">
<img alt="Hugging Face Weights" src="https://img.shields.io/badge/🤗-Hugging%20Face%20Weights-yellow" />
</a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
</p>
Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
#### Powered by Leading Research:
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- Simul-[Whisper](https://arxiv.org/pdf/2406.10052)/[Streaming](https://arxiv.org/abs/2506.17077) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
@@ -51,9 +53,11 @@ pip install whisperlivekit
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
> - See [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - Check the [troubleshooting guide](docs/troubleshooting.md) for step-by-step fixes collected from recent GPU setup/env issues.
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
#### Use it to capture audio from web pages.
@@ -96,11 +100,13 @@ wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
from whisperlivekit import AudioProcessor, TranscriptionEngine, parse_args
transcription_engine = None
@@ -139,15 +145,15 @@ async def websocket_endpoint(websocket: WebSocket):
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
| `--no-vac` | Disable Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` |
@@ -155,6 +161,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). Only works with native Whisper backend (`--backend whisper`) | `None` |
| Translation options | Description | Default |
|-----------|-------------|---------|
@@ -164,7 +171,7 @@ async def websocket_endpoint(websocket: WebSocket):
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
| `--disable-punctuation-split` | [NOT FUNCTIONAL IN 0.2.15 / 0.2.16] Disable punctuation based splits. See #214 | `False` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
@@ -182,8 +189,7 @@ async def websocket_endpoint(websocket: WebSocket):
| `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
| `--max-context-tokens` | Maximum context tokens | Depends on model used, but usually 448. |

View File

@@ -1,258 +0,0 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
#### 主要な研究による技術:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
### アーキテクチャ
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
### インストールとクイックスタート
```bash
pip install whisperlivekit
```
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
>
> | OS | インストール方法 |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
#### クイックスタート
1. **文字起こしサーバーを起動します:**
```bash
whisperlivekit-server --model base --language en
```
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
#### オプションの依存関係
| オプション | `pip install` |
|-----------|-------------|
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Diartによる話者ダイアライゼーション | `diart` |
| オリジナルのWhisperバックエンド | `whisper` |
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
| Apple Silicon最適化バックエンド | `mlx-whisper` |
| OpenAI APIバックエンド | `openai` |
それらの使用方法については、以下の**パラメータと設定**を参照してください。
### 使用例
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
```bash
# デフォルト(small)より良いモデルを使用
whisperlivekit-server --model large-v3
# ダイアライゼーションと言語を指定した高度な設定
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket.accept()
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
```
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
## パラメータと設定
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
- `--backend` `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
- `--warmup-file`、もしあれば
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
- `--diarization`、使用したい場合。
残りは推奨しません。しかし、以下があなたのオプションです。
| パラメータ | 説明 | デフォルト |
|-----------|-------------|---------|
| `--model` | Whisperモデルのサイズ。 | `small` |
| `--language` | ソース言語コードまたは`auto` | `auto` |
| `--task` | `transcribe`または`translate` | `transcribe` |
| `--backend` | 処理バックエンド | `simulstreaming` |
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
| `--no-vad` | 音声区間検出を無効化 | `False` |
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
| `--host` | サーバーホストアドレス | `localhost` |
| `--port` | サーバーポート | `8000` |
| `--ssl-certfile` | SSL証明書ファイルへのパスHTTPSサポート用 | `None` |
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパスHTTPSサポート用 | `None` |
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment` | `segment` |
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAttフレームしきい値低いほど速く、高いほど正確 | `25` |
| `--beams` | ビームサーチのビーム数1 = 貪欲デコーディング) | `1` |
| `--decoder` | デコーダタイプを強制(`beam`または`greedy` | `auto` |
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
| `--init-prompt` | モデルの初期プロンプト | `None` |
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
| ダイアライゼーションオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--diarization` | 話者識別を有効化 | `False` |
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
>4. HuggingFaceでログイン: `huggingface-cli login`
### 🚀 デプロイガイド
WhisperLiveKitを本番環境にデプロイするには
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
3. **Nginx設定** (本番環境で推奨):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
## 🐋 Docker
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
### 前提条件
- Dockerがシステムにインストールされていること
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
### クイックスタート
**GPUアクセラレーション付き (推奨):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
**CPUのみ:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### 高度な使用法
**カスタム設定:**
```bash
# カスタムモデルと言語の例
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### メモリ要件
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
#### カスタマイズ
- `--build-arg` オプション:
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
## 🔮 ユースケース
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...

Binary file not shown.

Before

Width:  |  Height:  |  Size: 406 KiB

After

Width:  |  Height:  |  Size: 422 KiB

View File

@@ -1,53 +1,22 @@
# WhisperLiveKit WebSocket API Documentation
> !! **Note**: The new API structure described in this document is currently under deployment.
This documentation is intended for devs who want to build custom frontends.
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends updates as audio is processed, allowing clients to display live transcription results with minimal latency.
---
## Legacy API (Current)
## Endpoints
### Message Structure
The current API sends complete state snapshots on each update (several time per second)
```typescript
{
"type": str,
"status": str,
"lines": [
{
"speaker": int,
"text": str,
"start": float,
"end": float,
"translation": str | null,
"detected_language": str
}
],
"buffer_transcription": str,
"buffer_diarization": str,
"remaining_time_transcription": float,
"remaining_time_diarization": float
}
```
| Endpoint | Description |
|----------|-------------|
| `/` | Main web interface with visual styling |
| `/text` | Simple text-based interface for easy copy/paste (debug/development) |
| `/asr` | WebSocket endpoint for audio streaming |
---
## New API (Under Development)
### Philosophy
Principles:
- **Incremental Updates**: Only updates and new segments are sent
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
## Message Format
### Transcript Update (Server → Client)
```typescript
{
@@ -58,22 +27,11 @@ Principles:
"id": number,
"speaker": number,
"text": string,
"start_speaker": float,
"start": float,
"end": float,
"start_speaker": string, // HH:MM:SS format
"start": string, // HH:MM:SS format
"end": string, // HH:MM:SS format
"language": string | null,
"translation": string,
"words": [
{
"text": string,
"start": float,
"end": float,
"validated": {
"text": boolean,
"speaker": boolean,
}
}
],
"buffer": {
"transcription": string,
"diarization": string,
@@ -94,9 +52,10 @@ Principles:
```json
{
"type": "config",
"useAudioWorklet": true / false
"useAudioWorklet": true
}
```
- `useAudioWorklet`: If `true`, client should use AudioWorklet for PCM streaming. If `false`, use MediaRecorder for WebM.
#### Ready to Stop Message (sent after processing complete)
```json
@@ -104,6 +63,7 @@ Principles:
"type": "ready_to_stop"
}
```
Indicates all audio has been processed and the client can safely close the connection.
---
@@ -113,152 +73,179 @@ Principles:
| Field | Type | Description |
|-------|------|-------------|
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
| `id` | `number` | Unique identifier for this segment. |
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
| `words` | `Array` | Array of word-level objects with timing and validation information. |
| `buffer` | `Object` | Per-segment temporary buffers, see below |
### Word Object
| Field | Type | Description |
|-------|------|-------------|
| `text` | `string` | The word text. |
| `start` | `number` | Start timestamp (seconds) of this word. |
| `end` | `number` | End timestamp (seconds) of this word. |
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
| `text` | `string` | Validated transcription text. |
| `start_speaker` | `string` | Timestamp (HH:MM:SS) when this speaker segment began. |
| `start` | `string` | Timestamp (HH:MM:SS) of the first word. |
| `end` | `string` | Timestamp (HH:MM:SS) of the last word. |
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until detected. |
| `translation` | `string` | Validated translation text. |
| `buffer` | `Object` | Per-segment temporary buffers (see below). |
### Buffer Object (Per-Segment)
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
Buffers are **ephemeral**. They should be displayed to the user but are overwritten on each update. Only the **last non-silent segment** contains buffer content.
| Field | Type | Description |
|-------|------|-------------|
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
| `transcription` | `string` | Text pending validation (waiting for more context). |
| `diarization` | `string` | Text pending speaker assignment (diarization hasn't caught up). |
| `translation` | `string` | Translation pending validation. |
### Metadata Fields
| Field | Type | Description |
|-------|------|-------------|
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription. |
| `remaining_time_diarization` | `float` | Seconds of audio waiting for diarization. |
### Status Values
| Status | Description |
|--------|-------------|
| `active_transcription` | Normal operation, transcription is active. |
| `no_audio_detected` | No audio has been detected yet. |
| `no_audio_detected` | No audio/speech has been detected yet. |
---
## Update Behavior
## Behavior Notes
### Incremental Updates
### Silence Handling
The API sends **only changed or new segments**. Clients should:
- **Short silences (< 2 seconds)** are filtered out and not displayed.
- Only significant pauses appear as silence segments with `speaker: -2`.
- Consecutive same-speaker segments are merged even across short silences.
1. Maintain a local map of segments by ID
2. When receiving an update, merge/update segments by ID
3. Render only the changed segments
### Update Frequency
### Language Detection
- **Active transcription**: ~20 updates/second (every 50ms)
- **During silence**: ~2 updates/second (every 500ms) to reduce bandwidth
When language is detected for a segment:
### Token-by-Token Validation (Diarization Mode)
```jsonc
// Update 1: No language yet
{
"segments": [
{"id": 1, "speaker": 1, "text": "May see", "language": null}
]
}
// Update 2: Same segment ID, language now detected
{
"segments": [
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
]
}
```
**Client behavior**: **Replace** the existing segment with the same ID.
### Buffer Behavior
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
#### Example: Translation with diarization and translation
```jsonc
// Update 1
When diarization is enabled, text is validated **token-by-token** as soon as diarization covers each token, rather than waiting for punctuation. This provides:
- Faster text validation
- More responsive speaker attribution
- Buffer only contains tokens that diarization hasn't processed yet
---
## Example Messages
### Normal Transcription
```json
{
"type": "transcript_update",
"status": "active_transcription",
"segments": [
{
"id": 1,
"speaker": 1,
"text": "Hello world, how are",
"text": "Hello, how are you today?",
"start_speaker": "0:00:02",
"start": "0:00:02",
"end": "0:00:05",
"language": "en",
"translation": "",
"buffer": {
"transcription": " I'm doing",
"diarization": "",
"translation": ""
}
}
],
"metadata": {
"remaining_time_transcription": 0.5,
"remaining_time_diarization": 0
}
}
```
### With Diarization Buffer
```json
{
"type": "transcript_update",
"status": "active_transcription",
"segments": [
{
"id": 1,
"speaker": 1,
"text": "The meeting starts at nine.",
"start_speaker": "0:00:03",
"start": "0:00:03",
"end": "0:00:06",
"language": "en",
"translation": "",
"buffer": {
"transcription": "",
"diarization": " you on",
"translation": "Bonjour le monde"
"diarization": " Let me check my calendar",
"translation": ""
}
}
]
],
"metadata": {
"remaining_time_transcription": 0.3,
"remaining_time_diarization": 2.1
}
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
// Update 2
{
"segments": [
{
"id": 1,
"speaker": 1,
"text": " you on this",
"translation": "Bonjour tout le monde",
"buffer": {
"transcription": "",
"diarization": " beautiful day",
"translation": ",comment"
}
},
]
}
// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
```
### Silence Segments
### Silence Segment
Silence is represented with the speaker id = `-2`:
```jsonc
```json
{
"id": 5,
"speaker": -2,
"text": "",
"start": 10.5,
"end": 12.3
"start_speaker": "0:00:10",
"start": "0:00:10",
"end": "0:00:15",
"language": null,
"translation": "",
"buffer": {
"transcription": "",
"diarization": "",
"translation": ""
}
}
```
---
## Text Transcript Endpoint (`/text`)
The `/text` endpoint provides a simple, monospace text interface designed for:
- Easy copy/paste of transcripts
- Debugging and development
- Integration testing
Output uses text markers instead of HTML styling:
```
[METADATA transcription_lag=0.5s diarization_lag=1.2s]
[SPEAKER 1] 0:00:03 - 0:00:11 [LANG: en]
Hello world, how are you doing today?[DIAR_BUFFER] I'm doing fine[/DIAR_BUFFER]
[SILENCE 0:00:15 - 0:00:18]
[SPEAKER 2] 0:00:18 - 0:00:22 [LANG: en]
That's great to hear!
[TRANSLATION]C'est super à entendre![/TRANSLATION]
```
### Markers
| Marker | Description |
|--------|-------------|
| `[SPEAKER N]` | Speaker label with ID |
| `[SILENCE start - end]` | Silence segment |
| `[LANG: xx]` | Detected language code |
| `[DIAR_BUFFER]...[/DIAR_BUFFER]` | Text pending speaker assignment |
| `[TRANS_BUFFER]...[/TRANS_BUFFER]` | Text pending validation |
| `[TRANSLATION]...[/TRANSLATION]` | Translation content |
| `[METADATA ...]` | Lag/timing information |

View File

@@ -1,13 +1,73 @@
### Alignment between STT Tokens and Diarization Segments
# Alignment Principles
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
This document explains how transcription tokens are aligned with diarization (speaker identification) segments.
> `#` Is the split between the `t-1` prediction and `t` prediction.
---
## Token-by-Token Validation
When diarization is enabled, text is validated **token-by-token** rather than waiting for sentence boundaries. As soon as diarization covers a token's time range, that token is validated and assigned to the appropriate speaker.
### How It Works
1. **Transcription produces tokens** with timestamps (start, end)
2. **Diarization produces speaker segments** with timestamps
3. **For each token**: Check if diarization has caught up to that token's time
- If yes → Find speaker with maximum overlap, validate token
- If no → Keep token in "pending" (becomes diarization buffer)
```
Timeline: 0s -------- 5s -------- 10s -------- 15s
| | | |
Transcription: [Hello, how are you doing today?]
|_______|___|____|_____|_____|_____|
tok1 tok2 tok3 tok4 tok5 tok6
Diarization: [SPEAKER 1 ][SPEAKER 2 ]
|__________________|__________________|
0s 8s 15s
At time t when diarization covers up to 8s:
- Tokens 1-4 (0s-7s) → Validated as SPEAKER 1
- Tokens 5-6 (7s-10s) → In buffer (diarization hasn't caught up)
```
---
## Silence Handling
- **Short silences (< 2 seconds)**: Filtered out, not displayed
- **Significant silences (≥ 2 seconds)**: Displayed as silence segments with `speaker: -2`
- **Same speaker across gaps**: Segments are merged even if separated by short silences
```
Before filtering:
[SPK1 0:00-0:03] [SILENCE 0:03-0:04] [SPK1 0:04-0:08]
After filtering (silence < 2s):
[SPK1 0:00-0:08] ← Merged into single segment
```
---
## Buffer Types
| Buffer | Contains | Displayed When |
|--------|----------|----------------|
| `transcription` | Text awaiting validation (more context needed) | Always on last segment |
| `diarization` | Text awaiting speaker assignment | When diarization lags behind transcription |
| `translation` | Translation awaiting validation | When translation is enabled |
---
## Legacy: Punctuation-Based Splitting
The previous approach split segments at punctuation marks and aligned with diarization at those boundaries. This is now replaced by token-by-token validation for faster, more responsive results.
### Historical Examples (for reference)
Example of punctuation-based alignment:
## Example 1:
```text
punctuations_segments : __#_______.__________________!____
diarization_segments:
@@ -16,56 +76,6 @@ SPK2 # ___________________
-->
ALIGNED SPK1 __#_______.
ALIGNED SPK2 # __________________!____
t-1 output:
SPK1: __#
SPK2: NO
DIARIZATION BUFFER: NO
t output:
SPK1: __#__.
SPK2: __________________!____
DIARIZATION BUFFER: No
```
## Example 2:
```text
punctuations_segments : _____#__.___________
diarization_segments:
SPK1 ___ #
SPK2 __#______________
-->
ALIGNED SPK1 _____#__.
ALIGNED SPK2 # ___________
t-1 output:
SPK1: ___ #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: __#__.
SPK2: ___________
DIARIZATION BUFFER: No
```
## Example 3:
```text
punctuations_segments : ___.__#__________
diarization_segments:
SPK1 ______#__
SPK2 # ________
-->
ALIGNED SPK1 ___. #
ALIGNED SPK2 __#__________
t-1 output:
SPK1: ___. #
SPK2:
DIARIZATION BUFFER: __#
t output:
SPK1: #
SPK2: __#___________
DIARIZATION BUFFER: NO
```
With token-by-token validation, the alignment happens continuously rather than at punctuation boundaries.

View File

@@ -1,109 +0,0 @@
# Available Whisper model sizes:
- tiny.en (english only)
- tiny
- base.en (english only)
- base
- small.en (english only)
- small
- medium.en (english only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo
## How to choose?
### Language Support
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
### Resource Constraints
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
- `base`: Good balance of speed and accuracy for basic use cases
- `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
- `large-v2`: Excellent accuracy, good multilingual support
- `large-v3`: Best overall accuracy and language support
### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Model Comparison Table
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
- `largev3turbo`: ~6GB VRAM
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
## Quick Decision Matrix
**Choose 600M**: Limited resources, close to 0 lag
**Choose 1.3B**: Quality matters
**Choose Transformers**: On Apple Silicon

View File

@@ -0,0 +1,106 @@
# Models and Model Paths
## Defaults
**Default Whisper Model**: `base`
When no model is specified, WhisperLiveKit uses the `base` model, which provides a good balance of speed and accuracy for most use cases.
**Default Model Cache Directory**: `~/.cache/whisper`
Models are automatically downloaded from OpenAI's model hub and cached in this directory. You can override this with `--model_cache_dir`.
**Default Translation Model**: `600M` (NLLB-200-distilled)
When translation is enabled, the 600M distilled NLLB model is used by default. This provides good quality with minimal resource usage.
**Default Translation Backend**: `transformers`
The translation backend defaults to Transformers. On Apple Silicon, this automatically uses MPS acceleration for better performance.
---
## Available Whisper model sizes:
| Available Model | Speed | Accuracy | Multilingual | Translation | Hardware Requirements | Best Use Case |
|--------------------|----------|-----------|--------------|-------------|----------------------|----------------------------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | ~1GB VRAM | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | ~1GB VRAM | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | ~2GB VRAM | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | ~5GB VRAM | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Good overall accuracy & language support |
| large-v3 | Slowest | Excellent | Yes | Yes | ~10GB VRAM | Best overall accuracy & language support |
| large-v3-turbo | Fast | Excellent | Yes | No | ~6GB VRAM | Fast, high-quality transcription |
### How to choose?
#### Language Support
- **English only**: Use `.en` (ex: `base.en`) models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
#### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
_______________________
# Custom Models:
The `--model-path` parameter accepts:
## File Path
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
## Hugging Face Repo ID
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignment heads are set to be all the heads of the last half layer of decoder.
_______________________
# Translation Models and Backend
**Language Support**: ~200 languages
## Distilled Model Sizes Available
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
## Backend Performance
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: Up to 2x faster than CTranslate2

View File

@@ -1,19 +0,0 @@
# Model Path Formats
The `--model-path` parameter accepts:
## File Path
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
## Hugging Face Repo ID
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.

View File

@@ -1,6 +1,114 @@
# Supported Languages
# Transcription: Supported Language
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
WLK supports transcription in the following languages:
| ISO Code | Language Name |
|----------|---------------------|
| en | English |
| zh | Chinese |
| de | German |
| es | Spanish |
| ru | Russian |
| ko | Korean |
| fr | French |
| ja | Japanese |
| pt | Portuguese |
| tr | Turkish |
| pl | Polish |
| ca | Catalan |
| nl | Dutch |
| ar | Arabic |
| sv | Swedish |
| it | Italian |
| id | Indonesian |
| hi | Hindi |
| fi | Finnish |
| vi | Vietnamese |
| he | Hebrew |
| uk | Ukrainian |
| el | Greek |
| ms | Malay |
| cs | Czech |
| ro | Romanian |
| da | Danish |
| hu | Hungarian |
| ta | Tamil |
| no | Norwegian |
| th | Thai |
| ur | Urdu |
| hr | Croatian |
| bg | Bulgarian |
| lt | Lithuanian |
| la | Latin |
| mi | Maori |
| ml | Malayalam |
| cy | Welsh |
| sk | Slovak |
| te | Telugu |
| fa | Persian |
| lv | Latvian |
| bn | Bengali |
| sr | Serbian |
| az | Azerbaijani |
| sl | Slovenian |
| kn | Kannada |
| et | Estonian |
| mk | Macedonian |
| br | Breton |
| eu | Basque |
| is | Icelandic |
| hy | Armenian |
| ne | Nepali |
| mn | Mongolian |
| bs | Bosnian |
| kk | Kazakh |
| sq | Albanian |
| sw | Swahili |
| gl | Galician |
| mr | Marathi |
| pa | Punjabi |
| si | Sinhala |
| km | Khmer |
| sn | Shona |
| yo | Yoruba |
| so | Somali |
| af | Afrikaans |
| oc | Occitan |
| ka | Georgian |
| be | Belarusian |
| tg | Tajik |
| sd | Sindhi |
| gu | Gujarati |
| am | Amharic |
| yi | Yiddish |
| lo | Lao |
| uz | Uzbek |
| fo | Faroese |
| ht | Haitian Creole |
| ps | Pashto |
| tk | Turkmen |
| nn | Nynorsk |
| mt | Maltese |
| sa | Sanskrit |
| lb | Luxembourgish |
| my | Myanmar |
| bo | Tibetan |
| tl | Tagalog |
| mg | Malagasy |
| as | Assamese |
| tt | Tatar |
| haw | Hawaiian |
| ln | Lingala |
| ha | Hausa |
| ba | Bashkir |
| jw | Javanese |
| su | Sundanese |
| yue | Cantonese |
# Translation: Supported Languages
WLK supports translation into **201 languages** from the FLORES-200 dataset through the [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) translation system.
## How to Specify Languages

View File

@@ -40,4 +40,4 @@ This document introduce how to reuse the core components when you do **not** wan
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently. Just ensure `ffmpeg` is available.

140
docs/troubleshooting.md Normal file
View File

@@ -0,0 +1,140 @@
# Troubleshooting
## GPU drivers & cuDNN visibility
### Linux error: `Unable to load libcudnn_ops.so* / cudnnCreateTensorDescriptor`
> Reported in issue #271 (Arch/CachyOS)
`faster-whisper` (used for the SimulStreaming encoder) dynamically loads cuDNN.
If the runtime cannot find `libcudnn_*`, verify that CUDA and cuDNN match the PyTorch build you installed:
1. **Install CUDA + cuDNN** (Arch/CachyOS example):
```bash
sudo pacman -S cuda cudnn
sudo ldconfig
```
2. **Make sure the shared objects are visible**:
```bash
ls /usr/lib/libcudnn*
```
3. **Check what CUDA version PyTorch expects** and match that with the driver you installed:
```bash
python - <<'EOF'
import torch
print(torch.version.cuda)
EOF
nvcc --version
```
4. If you installed CUDA in a non-default location, export `CUDA_HOME` and add `$CUDA_HOME/lib64` to `LD_LIBRARY_PATH`.
Once the CUDA/cuDNN versions match, `whisperlivekit-server` starts normally.
### Windows error: `Could not locate cudnn_ops64_9.dll`
> Reported in issue #286 (Conda on Windows)
PyTorch bundles cuDNN DLLs inside your environment (`<env>\Lib\site-packages\torch\lib`).
When `ctranslate2` or `faster-whisper` cannot find `cudnn_ops64_9.dll`:
1. Locate the DLL shipped with PyTorch, e.g.
```
E:\conda\envs\WhisperLiveKit\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
```
2. Add that directory to your `PATH` **or** copy the `cudnn_*64_9.dll` files into a directory that is already on `PATH` (such as the environment's `Scripts/` folder).
3. Restart the shell before launching `wlk`.
Installing NVIDIA's standalone cuDNN 9.x and pointing `PATH`/`CUDNN_PATH` to it works as well, but is usually not required.
---
## PyTorch / CTranslate2 GPU builds
### `Torch not compiled with CUDA enabled`
> Reported in issue #284
If `torch.zeros(1).cuda()` raises that assertion it means you installed a CPU-only wheel.
Install the GPU-enabled wheels that match your CUDA toolkit:
```bash
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
```
Replace `cu130` with the CUDA version supported by your driver (see [PyTorch install selector](https://pytorch.org/get-started/locally/)).
Validate with:
```python
import torch
print(torch.cuda.is_available(), torch.cuda.get_device_name())
```
### `CTranslate2 device count: 0` or `Could not infer dtype of ctranslate2._ext.StorageView`
> Follow-up in issue #284
`ctranslate2` publishes separate CPU and CUDA wheels. The default `pip install ctranslate2` brings the CPU build, which makes WhisperLiveKit fall back to CPU tensors and leads to the dtype error above.
1. Uninstall the CPU build: `pip uninstall -y ctranslate2`.
2. Install the CUDA wheel that matches your toolkit (example for CUDA 13.0):
```bash
pip install ctranslate2==4.5.0 -f https://opennmt.net/ctranslate2/whl/cu130
```
(See the [CTranslate2 installation table](https://opennmt.net/CTranslate2/installation.html) for other CUDA versions.)
3. Verify:
```python
import ctranslate2
print("CUDA devices:", ctranslate2.get_cuda_device_count())
print("CUDA compute types:", ctranslate2.get_supported_compute_types("cuda", 0))
```
**Note for aarch64 systems (e.g., NVIDIA DGX Spark):** Pre-built CUDA wheels may not be available for all CUDA versions on ARM architectures. If the wheel installation fails, you may need to compile CTranslate2 from source with CUDA support enabled.
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
---
## Hopper / Blackwell (`sm_121a`) systems
> Reported in issues #276 and #284 (NVIDIA DGX Spark)
CUDA 12.1a GPUs (e.g., NVIDIA GB10 on DGX Spark) ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual configuration.
### Error: `ptxas fatal : Value 'sm_121a' is not defined for option 'gpu-name'`
If you encounter this error after compiling CTranslate2 from source on aarch64 systems, Triton's bundled `ptxas` may not support the `sm_121a` architecture. The solution is to replace Triton's `ptxas` with the system's CUDA `ptxas`:
```bash
# Find your Python environment's Triton directory
python -c "import triton; import os; print(os.path.dirname(triton.__file__))"
# Copy the system ptxas to Triton's backend directory
# Replace <triton_path> with the output above
cp /usr/local/cuda/bin/ptxas <triton_path>/backends/nvidia/bin/ptxas
```
For example, in a virtual environment:
```bash
cp /usr/local/cuda/bin/ptxas ~/wlk/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
```
**Note:** On DGX Spark systems, CUDA is typically already in `PATH` (`/usr/local/cuda/bin`), so explicit `CUDA_HOME` and `PATH` exports may not be necessary. Verify with `which ptxas` before copying.
### Alternative: Environment variable approach
If the above doesn't work, you can try setting environment variables (though this may not resolve the `sm_121a` issue on all systems):
```bash
export CUDA_HOME="/usr/local/cuda-13.0"
export PATH="$CUDA_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
# Tell Triton where the new ptxas lives
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
# Force PyTorch to JIT kernels for all needed architectures
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
```
After applying the fix, restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
---
Need help with another recurring issue? Open a GitHub discussion or PR and reference this document so we can keep it current.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.15"
version = "0.2.16.dev0"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
@@ -61,10 +61,10 @@ packages = [
"whisperlivekit.whisper.normalizers",
"whisperlivekit.web",
"whisperlivekit.local_agreement",
"whisperlivekit.vad_models"
"whisperlivekit.silero_vad_models"
]
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -14,10 +14,10 @@ from typing import Dict, Tuple
import torch
from whisperlivekit.whisper import _convert_hf_state_dict
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
from whisperlivekit.whisper.model import ModelDimensions
from whisperlivekit.whisper.utils import exact_div
from whisperlivekit.whisper import _convert_hf_state_dict
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:

View File

@@ -5,16 +5,18 @@ import argparse
import base64
import gzip
import io
import math
import pathlib
import sys
import math
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from datasets import Audio as DatasetAudio, load_dataset
import soundfile as sf
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from datasets import Audio as DatasetAudio
from datasets import load_dataset
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
WHISPER_ROOT = REPO_ROOT / "whisper"

View File

@@ -1,9 +1,10 @@
"""Copy core files from web directory to Chrome extension directory."""
import shutil
import os
import shutil
from pathlib import Path
def sync_extension_files():
web_dir = Path("whisperlivekit/web")

View File

@@ -1,7 +1,7 @@
from .audio_processor import AudioProcessor
from .core import TranscriptionEngine
from .parse_args import parse_args
from .web.web_interface import get_web_interface_html, get_inline_ui_html
from .web.web_interface import get_inline_ui_html, get_text_transcript_html, get_web_interface_html
__all__ = [
"TranscriptionEngine",
@@ -9,5 +9,6 @@ __all__ = [
"parse_args",
"get_web_interface_html",
"get_inline_ui_html",
"get_text_transcript_html",
"download_simulstreaming_backend",
]

View File

@@ -1,14 +1,20 @@
import asyncio
import numpy as np
from time import time
import logging
import traceback
from typing import Optional, Union, List, Any, AsyncGenerator
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from time import time
from typing import Any, AsyncGenerator, List, Optional, Union
import numpy as np
from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory,
online_translation_factory)
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
Segment, Silence, State, Transcript)
from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -387,6 +393,10 @@ class AudioProcessor:
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
"""Format processing results for output."""
# Update intervals
ACTIVE_INTERVAL = 0.05 # 20 updates/sec during active transcription
SILENCE_INTERVAL = 0.5 # 2 updates/sec during silence
while True:
try:
if self._ffmpeg_error:
@@ -396,25 +406,35 @@ class AudioProcessor:
continue
self.tokens_alignment.update()
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
state = await self.get_current_state()
# Get transcription buffer text to pass to get_lines
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
# get_lines now returns segments with per-segment buffers
segments = self.tokens_alignment.get_lines(
diarization=self.args.diarization,
translation=bool(self.translation),
current_silence=self.current_silence
current_silence=self.current_silence,
buffer_transcription=buffer_transcription_text
)
state = await self.get_current_state()
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
response_status = "active_transcription"
if not lines and not buffer_transcription_text and not buffer_diarization_text:
# Check if there's any content (segments with text or buffers)
has_active_content = any(
seg.buffer and (seg.buffer.transcription or seg.buffer.diarization)
for seg in segments if not seg.is_silence()
)
has_any_content = any(
seg.text or (seg.buffer and (seg.buffer.transcription or seg.buffer.diarization))
for seg in segments if not seg.is_silence()
)
if not segments or not has_any_content:
response_status = "no_audio_detected"
response = FrontData(
status=response_status,
lines=lines,
buffer_transcription=buffer_transcription_text,
buffer_diarization=buffer_diarization_text,
buffer_translation=buffer_translation_text,
segments=segments,
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
)
@@ -428,7 +448,15 @@ class AudioProcessor:
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return
await asyncio.sleep(0.05)
# Throttle updates during silence: use slower interval when in silence mode
# with no pending buffers (nothing actively being processed)
is_in_silence = self.current_silence is not None
has_pending_work = has_active_content or state.remaining_time_transcription > 0.5
if is_in_silence and not has_pending_work:
await asyncio.sleep(SILENCE_INTERVAL)
else:
await asyncio.sleep(ACTIVE_INTERVAL)
except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
@@ -603,16 +631,16 @@ class AudioProcessor:
res = self.vac(pcm_array)
if res is not None:
silence_detected = res.get("end", 0) > res.get("start", 0)
if silence_detected and not self.current_silence:
if "start" in res and self.current_silence:
await self._end_silence()
if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end")
)
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
await self._enqueue_active_audio(pre_silence_chunk)
await self._begin_silence()
elif self.current_silence:
await self._end_silence()
if not self.current_silence:
await self._enqueue_active_audio(pcm_array)

View File

@@ -1,10 +1,13 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
get_inline_ui_html, get_text_transcript_html, parse_args)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
@@ -36,6 +39,12 @@ async def get():
return HTMLResponse(get_inline_ui_html())
@app.get("/text")
async def get_text():
"""Simple text-based transcript view for easy copy/paste."""
return HTMLResponse(get_text_transcript_html())
async def handle_websocket_results(websocket, results_generator):
"""Consumes results from the audio processor and sends them via WebSocket."""
try:

View File

@@ -1,9 +1,11 @@
import logging
import sys
from argparse import Namespace
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
from whisperlivekit.local_agreement.whisper_online import backend_factory
from whisperlivekit.simul_whisper import SimulStreamingASR
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
from argparse import Namespace
import sys
import logging
def update_with_kwargs(_dict, kwargs):
_dict.update({
@@ -57,6 +59,7 @@ class TranscriptionEngine:
"model_cache_dir": None,
"model_dir": None,
"model_path": None,
"lora_path": None,
"lan": "auto",
"direct_english_translation": False,
}
@@ -80,6 +83,7 @@ class TranscriptionEngine:
if self.args.vac:
from whisperlivekit.silero_vad_iterator import load_silero_vad
# Use ONNX if specified, otherwise use JIT (default)
use_onnx = kwargs.get('vac_onnx', False)
self.vac_model = load_silero_vad(onnx=use_onnx)
@@ -100,7 +104,6 @@ class TranscriptionEngine:
"init_prompt": None,
"static_init_prompt": None,
"max_context_tokens": None,
"preload_model_count": 1,
}
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
@@ -135,7 +138,8 @@ class TranscriptionEngine:
if self.args.diarization:
if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization
from whisperlivekit.diarization.diart_backend import \
DiartDiarization
diart_params = {
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
@@ -146,7 +150,8 @@ class TranscriptionEngine:
**diart_params
)
elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
from whisperlivekit.diarization.sortformer_backend import \
SortformerDiarization
self.diarization_model = SortformerDiarization()
self.translation_model = None
@@ -182,7 +187,8 @@ def online_diarization_factory(args, diarization_backend):
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
if args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
from whisperlivekit.diarization.sortformer_backend import \
SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend)
return online

View File

@@ -1,20 +1,20 @@
import asyncio
import logging
import re
import threading
import numpy as np
import logging
import time
from queue import SimpleQueue, Empty
from queue import Empty, SimpleQueue
from typing import Any, List, Tuple
import diart.models as m
import numpy as np
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.sources import AudioSource
from whisperlivekit.timed_objects import SpeakerSegment
from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
from diart.sources import AudioSource, MicrophoneAudioSource
from pyannote.core import Annotation
import diart.models as m
from rx.core import Observer
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__)

View File

@@ -1,11 +1,12 @@
import numpy as np
import torch
import logging
import threading
import time
import wave
from queue import Empty, SimpleQueue
from typing import List, Optional
from queue import SimpleQueue, Empty
import numpy as np
import torch
from whisperlivekit.timed_objects import SpeakerSegment
@@ -295,6 +296,7 @@ def extract_number(s: str) -> int:
if __name__ == '__main__':
import asyncio
import librosa
async def main():

View File

@@ -1,8 +1,8 @@
import asyncio
import contextlib
import logging
from enum import Enum
from typing import Optional, Callable
import contextlib
from typing import Callable, Optional
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

View File

@@ -1,21 +1,25 @@
import sys
import logging
import io
import soundfile as sf
import logging
import math
import sys
from typing import List
import numpy as np
import soundfile as sf
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
logger = logging.getLogger(__name__)
class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when needed)
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
self.logfile = logfile
self.transcribe_kargs = {}
self.lora_path = lora_path
if lan == "auto":
self.original_language = None
else:
@@ -44,24 +48,23 @@ class WhisperASR(ASRBase):
sep = " "
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from whisperlivekit.whisper import load_model as load_model
from whisperlivekit.whisper import load_model as load_whisper_model
if model_dir is not None:
resolved_path = resolve_model_path(model_dir)
resolved_path = resolve_model_path(model_dir)
if resolved_path.is_dir():
pytorch_path, _, _ = model_path_and_type(resolved_path)
if pytorch_path is None:
model_info = detect_model_format(resolved_path)
if not model_info.has_pytorch:
raise FileNotFoundError(
f"No supported PyTorch checkpoint found under {resolved_path}"
)
resolved_path = pytorch_path
)
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
return load_model(str(resolved_path))
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
if model_size is None:
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
return load_model(model_size, download_root=cache_dir)
return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)
def transcribe(self, audio, init_prompt=""):
options = dict(self.transcribe_kargs)
@@ -165,8 +168,8 @@ class MLXWhisper(ASRBase):
sep = ""
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx
from mlx_whisper.transcribe import ModelHolder, transcribe
if model_dir is not None:
resolved_path = resolve_model_path(model_dir)

View File

@@ -1,7 +1,9 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
import sys
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__)

View File

@@ -1,18 +1,19 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging
import platform
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
import sys
import time
from functools import lru_cache
import librosa
import numpy as np
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.warmup import warmup_asr
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
from whisperlivekit.backend_support import (
mlx_backend_available,
faster_backend_available,
)
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
logger = logging.getLogger(__name__)
@@ -76,6 +77,7 @@ def backend_factory(
model_cache_dir,
model_dir,
model_path,
lora_path,
direct_english_translation,
buffer_trimming,
buffer_trimming_sec,
@@ -86,16 +88,20 @@ def backend_factory(
backend_choice = backend
custom_reference = model_path or model_dir
resolved_root = None
pytorch_checkpoint = None
has_mlx_weights = False
has_fw_weights = False
has_pytorch = False
if custom_reference:
resolved_root = resolve_model_path(custom_reference)
if resolved_root.is_dir():
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
model_info = detect_model_format(resolved_root)
has_mlx_weights = model_info.compatible_whisper_mlx
has_fw_weights = model_info.compatible_faster_whisper
has_pytorch = model_info.has_pytorch
else:
pytorch_checkpoint = resolved_root
# Single file provided
has_pytorch = True
if backend_choice == "openai-api":
logger.debug("Using OpenAI API.")
@@ -120,8 +126,8 @@ def backend_factory(
model_override = str(resolved_root) if resolved_root is not None else None
else:
asr_cls = WhisperASR
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
if custom_reference and model_override is None:
model_override = str(resolved_root) if resolved_root is not None else None
if custom_reference and not has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
)
@@ -133,6 +139,7 @@ def backend_factory(
lan=lan,
cache_dir=model_cache_dir,
model_dir=model_override,
lora_path=lora_path if backend_choice == "whisper" else None,
)
e = time.time()
logger.info(f"done. It took {round(e-t,2)} seconds.")

View File

@@ -1,49 +1,195 @@
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
@dataclass
class ModelInfo:
"""Information about detected model format and files in a directory."""
path: Optional[Path] = None
pytorch_files: List[Path] = field(default_factory=list)
compatible_whisper_mlx: bool = False
compatible_faster_whisper: bool = False
@property
def has_pytorch(self) -> bool:
return len(self.pytorch_files) > 0
@property
def is_sharded(self) -> bool:
return len(self.pytorch_files) > 1
@property
def primary_pytorch_file(self) -> Optional[Path]:
"""Return the primary PyTorch file (or first shard for sharded models)."""
if not self.pytorch_files:
return None
return self.pytorch_files[0]
#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin
SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$")
FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"}
MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"}
CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"}
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
"""
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
CTranslate2 models have specific companion files that distinguish them
from PyTorch .bin files.
"""
n_indicators = 0
for indicator in CT2_INDICATOR_FILES: #test 1
if (directory / indicator).exists():
n_indicators += 1
if n_indicators == 0:
return False
config_path = directory / "config.json" #test 2
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
if config.get("model_type") == "whisper": #test 2
return False
except (json.JSONDecodeError, IOError):
pass
return True
def _collect_pytorch_files(directory: Path) -> List[Path]:
"""
Collect all PyTorch checkpoint files from a directory.
Handles:
- Single files: model.safetensors, pytorch_model.bin, *.pt
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
- Index-based sharded models (reads index file to find shards)
Returns files sorted appropriately (shards in order, or single file).
"""
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
index_path = directory / index_name
if index_path.exists():
try:
with open(index_path, "r", encoding="utf-8") as f:
index_data = json.load(f)
weight_map = index_data.get("weight_map", {})
if weight_map:
shard_names = sorted(set(weight_map.values()))
shards = [directory / name for name in shard_names if (directory / name).exists()]
if shards:
return shards
except (json.JSONDecodeError, IOError):
pass
sharded_groups = {}
single_files = {}
for file in directory.iterdir():
if not file.is_file():
continue
filename = file.name
suffix = file.suffix.lower()
if filename.startswith("adapter_"):
continue
match = SHARDED_PATTERN.match(filename)
if match:
base_name, shard_idx, total_shards, ext = match.groups()
key = (base_name, ext, int(total_shards))
if key not in sharded_groups:
sharded_groups[key] = []
sharded_groups[key].append((int(shard_idx), file))
continue
if filename == "model.safetensors":
single_files[0] = file # Highest priority
elif filename == "pytorch_model.bin":
single_files[1] = file
elif suffix == ".pt":
single_files[2] = file
elif suffix == ".safetensors" and not filename.startswith("adapter"):
single_files[3] = file
for (base_name, ext, total_shards), shards in sharded_groups.items():
if len(shards) == total_shards:
return [path for _, path in sorted(shards)]
for priority in sorted(single_files.keys()):
return [single_files[priority]]
return []
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
"""
Detect the model format in a given path.
This function analyzes a file or directory to determine:
- What PyTorch checkpoint files are available (including sharded models)
- Whether the directory contains MLX Whisper weights
- Whether the directory contains Faster-Whisper (CTranslate2) weights
Args:
model_path: Path to a model file or directory
Returns:
ModelInfo with detected format information
"""
path = Path(model_path)
info = ModelInfo(path=path)
if path.is_file():
suffix = path.suffix.lower()
if suffix in {".pt", ".safetensors", ".bin"}:
info.pytorch_files = [path]
return info
if not path.is_dir():
return info
for file in path.iterdir():
if not file.is_file():
continue
filename = file.name.lower()
if filename in MLX_WHISPER_MARKERS:
info.compatible_whisper_mlx = True
if filename in FASTER_WHISPER_MARKERS:
if _is_ct2_model_bin(path, filename):
info.compatible_faster_whisper = True
info.pytorch_files = _collect_pytorch_files(path)
return info
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
"""
Inspect the provided path and determine which model formats are available.
This is a compatibility wrapper around detect_model_format().
Returns:
pytorch_path: Path to a PyTorch checkpoint (if present).
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
compatible_whisper_mlx: True if MLX weights exist in this folder.
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
"""
path = Path(model_path)
compatible_whisper_mlx = False
compatible_faster_whisper = False
pytorch_path: Optional[Path] = None
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
pytorch_path = path
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
if path.is_dir():
for file in path.iterdir():
if not file.is_file():
continue
filename = file.name.lower()
suffix = file.suffix.lower()
if filename in {"weights.npz", "weights.safetensors"}:
compatible_whisper_mlx = True
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
compatible_faster_whisper = True
elif suffix in {".pt", ".safetensors"}:
pytorch_path = file
elif filename == "pytorch_model.bin":
pytorch_path = file
if pytorch_path is None:
fallback = path / "pytorch_model.bin"
if fallback.exists():
pytorch_path = fallback
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
info = detect_model_format(model_path)
return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper
def resolve_model_path(model_path: Union[str, Path]) -> Path:
@@ -59,7 +205,7 @@ def resolve_model_path(model_path: Union[str, Path]) -> Path:
try:
from huggingface_hub import snapshot_download
except ImportError as exc: # pragma: no cover - optional dependency guard
except ImportError as exc:
raise FileNotFoundError(
f"Model path '{model_path}' does not exist locally and huggingface_hub "
"is not installed to download it."

View File

@@ -1,6 +1,7 @@
from argparse import ArgumentParser
def parse_args():
parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
@@ -105,6 +106,13 @@ def parse_args():
default=None,
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
)
parser.add_argument(
"--lora-path",
type=str,
default=None,
dest="lora_path",
help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). Only works with native Whisper backend.",
)
parser.add_argument(
"--lan",
"--language",
@@ -295,14 +303,6 @@ def parse_args():
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
)
simulstreaming_group.add_argument(
"--preload-model-count",
type=int,
default=1,
dest="preload_model_count",
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
)
simulstreaming_group.add_argument(
"--nllb-backend",
type=str,

View File

@@ -1,8 +1,9 @@
import torch
import numpy as np
import warnings
from pathlib import Path
import numpy as np
import torch
"""
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
"""
@@ -123,7 +124,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
raise Exception(f'Available ONNX opset_version: {available_ops}')
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'vad_models'
data_dir = current_dir / 'silero_vad_models'
if onnx:
if opset_version == 16:
@@ -138,7 +139,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
)
else:
model_path = Path(model_path)
@@ -276,8 +277,10 @@ class FixedVADIterator(VADIterator):
elif r is not None:
if "end" in r:
ret["end"] = r["end"]
if "start" in r and "end" in ret:
del ret["end"]
if "start" in r:
ret["start"] = r["start"]
if "end" in ret:
del ret["end"]
return ret if ret != {} else None

View File

@@ -1,31 +1,30 @@
import sys
import numpy as np
import gc
import logging
from typing import List, Tuple, Optional
import os
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
import sys
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
from whisperlivekit.warmup import load_file
from whisperlivekit.whisper import load_model, tokenizer
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
import os
import gc
from pathlib import Path
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.backend_support import (
mlx_backend_available,
faster_backend_available,
)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
else:
mlx_model_mapping = {}
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
@@ -50,20 +49,19 @@ class SimulStreamingOnlineProcessor:
self.buffer = []
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
self.load_new_alignatt_instance()
#can be moved
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
def load_new_backend(self):
model = self.asr.get_new_model_instance()
def load_new_alignatt_instance(self):
"""Initialize AlignAtt decoder using the shared model."""
self.model = AlignAtt(
cfg=self.asr.cfg,
loaded_model=model,
loaded_model=self.asr.shared_model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
)
def start_silence(self):
tokens, processed_upto = self.process_iter(is_last=True)
@@ -71,7 +69,10 @@ class SimulStreamingOnlineProcessor:
def end_silence(self, silence_duration, offset):
"""
If silences are > MIN_DURATION_REAL_SILENCE, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
Handle silence period.
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
Otherwise, insert a small silence and shift the last_attend_frame.
"""
self.end += silence_duration
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
@@ -84,21 +85,20 @@ class SimulStreamingOnlineProcessor:
self.model.refresh_segment(complete=True)
self.model.global_time_offset = silence_duration + offset
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker):
self.process_iter(is_last=True)
self.model.refresh_segment(complete=True)
self.model.speaker = change_speaker.speaker
self.global_time_offset = change_speaker.start
"""Handle speaker change event."""
self.process_iter(is_last=True)
self.model.refresh_segment(complete=True)
self.model.speaker = change_speaker.speaker
self.model.global_time_offset = change_speaker.start
def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
@@ -112,15 +112,17 @@ class SimulStreamingOnlineProcessor:
"""
try:
timestamped_words = self.model.infer(is_last=is_last)
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
if not timestamped_words:
return [], self.end
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
self.buffer.extend(timestamped_words)
return [], self.end
self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end
except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end
@@ -136,12 +138,8 @@ class SimulStreamingOnlineProcessor:
logger.exception(f"SimulStreaming warmup failed: {e}")
def __del__(self):
# free the model and add a new model to stack.
# del self.model
gc.collect()
torch.cuda.empty_cache()
# self.asr.new_model_to_stack()
self.model.remove_hooks()
class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
@@ -161,34 +159,23 @@ class SimulStreamingASR():
self._resolved_model_path = None
self.encoder_backend = "whisper"
preferred_backend = getattr(self, "backend", "auto")
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
compatible_whisper_mlx, compatible_faster_whisper = True, True
if self.model_path:
resolved_model_path = resolve_model_path(self.model_path)
self._resolved_model_path = resolved_model_path
self.model_path = str(resolved_model_path)
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
if self.pytorch_path:
self.model_name = self.pytorch_path.stem
else:
self.model_name = Path(self.model_path).stem
model_info = detect_model_format(resolved_model_path)
compatible_whisper_mlx = model_info.compatible_whisper_mlx
compatible_faster_whisper = model_info.compatible_faster_whisper
if not model_info.has_pytorch:
raise FileNotFoundError(
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
)
)
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
elif self.model_size is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_name = self.model_size
else:
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
@@ -226,10 +213,7 @@ class SimulStreamingASR():
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.mlx_encoder, self.fw_encoder = None, None
if self.encoder_backend == "mlx-whisper":
print('Simulstreaming will use MLX whisper to increase encoding speed.')
@@ -253,8 +237,7 @@ class SimulStreamingASR():
device='auto',
compute_type='auto',
)
self.models = [self.load_model() for i in range(self.preload_model_count)]
self.shared_model = self.load_model()
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
@@ -298,16 +281,19 @@ class SimulStreamingASR():
return True
def load_model(self):
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model(
name=self.pytorch_path if self.pytorch_path else self.model_name,
download_root=self.model_path,
name=model_ref,
download_root=None,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads
)
custom_alignment_heads=self.custom_alignment_heads,
lora_path=lora_path,
)
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
if self.fast_encoder:
temp_model = AlignAtt(
cfg=self.cfg,
loaded_model=whisper_model,
@@ -315,27 +301,9 @@ class SimulStreamingASR():
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
return whisper_model
def get_new_model_instance(self):
"""
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
"""
if len(self.models) == 0:
self.models.append(self.load_model())
new_model = self.models.pop()
return new_model
# self.models[0]
def new_model_to_stack(self):
self.models.append(self.load_model())
def set_translate_task(self):
"""Set up translation task."""

View File

@@ -1,17 +1,32 @@
from torch import Tensor
from whisperlivekit.whisper.decoding import PyTorchInference
# extention of PyTorchInference for beam search
class BeamPyTorchInference(PyTorchInference):
def _kv_modules(self):
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
return key_modules + value_modules
class BeamPyTorchInference(PyTorchInference):
"""Extension of PyTorchInference for beam search with cross-attention support."""
def _kv_cache_ids(self):
"""Get cache_id strings for self-attention key/value modules."""
key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks]
value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks]
return key_ids + value_ids
def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))):
for module_cache_id in self._kv_modules():
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
from torch import Tensor
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
for cache_id in self._kv_cache_ids():
if cache_id in self.kv_cache:
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
def logits(
self,
tokens: Tensor,
audio_features: Tensor,
return_cross_attn: bool = False,
):
"""Get logits, optionally returning cross-attention weights."""
return self.model.decoder(
tokens, audio_features,
kv_cache=self.kv_cache,
return_cross_attn=return_cross_attn,
)

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class AlignAttConfig():
eval_data_path: str = "tmp"

View File

@@ -0,0 +1,80 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import torch
@dataclass
class DecoderState:
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[torch.Tensor] = field(default_factory=list)
initial_tokens: Optional[torch.Tensor] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[torch.Tensor] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
CIFLinear: Optional[torch.nn.Module] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens_fn: Any = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.kv_cache = {}
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = self.kv_cache
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
"""
Reset transient state for a new segment.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = {}
self.first_timestamp = None

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper
mlx_model_mapping = {

View File

@@ -1,33 +1,36 @@
import os
import logging
import torch
import torch.nn.functional as F
import numpy as np
from whisperlivekit.whisper import DecodingOptions, tokenizer
from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from whisperlivekit.whisper.timing import median_filter
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os
from time import time
from .token_buffer import TokenBuffer
from whisperlivekit.backend_support import (
mlx_backend_available,
faster_backend_available,
)
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
TOKENS_PER_SECOND,
log_mel_spectrogram, pad_or_trim)
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
SuppressTokens)
from whisperlivekit.whisper.timing import median_filter
from ..timed_objects import PUNCTUATION_MARKS
from .beam import BeamPyTorchInference
from .config import AlignAttConfig
from .decoder_state import DecoderState
from .eow_detection import fire_at_boundary, load_cif
from .token_buffer import TokenBuffer
DEC_PAD = 50257
logger = logging.getLogger(__name__)
if mlx_backend_available():
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.audio import \
log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
if faster_backend_available():
@@ -52,6 +55,30 @@ def load_coreml_encoder():
class AlignAtt:
"""
Alignment-based Attention decoder for SimulStreaming.
This class is now hookless - the model can be shared across multiple
sessions, with each session maintaining its own DecoderState.
"""
# Property accessors for backward compatibility
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__(
self,
cfg: AlignAttConfig,
@@ -59,8 +86,7 @@ class AlignAtt:
mlx_encoder=None,
fw_encoder=None,
) -> None:
self.log_segments = 0
# Shared model reference (can be shared across sessions)
self.model = loaded_model
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
@@ -74,119 +100,89 @@ class AlignAtt:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Model dimensions: {self.model.dims}")
self.speaker = -1
self.decode_options = DecodingOptions(
language = cfg.language,
without_timestamps = True,
language=cfg.language,
without_timestamps=True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
# self.create_tokenizer('en')
self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg
self.l_hooks = []
# model to detect end-of-word boundary at the end of the segment
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
n_audio_state=self.model.dims.n_audio_state,
device=self.model.device)
# install hooks to access encoder-decoder attention
self.dec_attns = []
def layer_hook(module, net_input, net_output):
# net_output[1]: B*num_head*token_len*audio_len
t = F.softmax(net_output[1], dim=-1)
self.dec_attns.append(t.squeeze(0))
for b in self.model.decoder.blocks:
hook = b.cross_attn.register_forward_hook(layer_hook)
self.l_hooks.append(hook)
self.kv_cache = {}
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
# save as-is, for the first token or cross attention
self.kv_cache[module.cache_id] = net_output
else:
x = self.kv_cache[module.cache_id]
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
return self.kv_cache[module.cache_id]
for i,b in enumerate(self.model.decoder.blocks):
hooks = [
b.attn.key.register_forward_hook(kv_hook),
b.attn.value.register_forward_hook(kv_hook),
b.cross_attn.key.register_forward_hook(kv_hook),
b.cross_attn.value.register_forward_hook(kv_hook),
]
self.l_hooks.extend(hooks)
self.align_source = {}
self.num_align_heads = 0
for layer_rank, head_id in self.model.alignment_heads.indices().T:
layer_rank = layer_rank.item()
heads = self.align_source.get(layer_rank, [])
heads.append((self.num_align_heads, head_id.item()))
self.align_source[layer_rank] = heads
self.num_align_heads += 1
# tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
# self.tokenizer.eot
self.tokenizer.no_timestamps, # added by DM
] + list(self.tokenizer.all_language_tokens) # added by DM
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {suppress_tokens}")
sup_tokens = SuppressTokens(suppress_tokens)
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334
# it's going to be regenerated after lang id
self.segments = []
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.first_timestamp = None
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = DecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
# Create tokenizer
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
# Timing state
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
# CIF helpers for end-of-word boundary detection
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
cfg,
n_audio_state=self.model.dims.n_audio_state,
device=self.model.device
)
# Build alignment source mapping from model's alignment_heads
self.state.align_source = {}
self.state.num_align_heads = 0
for layer_rank, head_id in self.model.alignment_heads.indices().T:
layer_rank = layer_rank.item()
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id.item()))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
# Build suppress tokens function
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {suppress_tokens}")
sup_tokens = SuppressTokens(suppress_tokens)
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
# Initialize tokens
self.init_tokens()
self.init_context()
# decoder type: greedy or beam
# Set up decoder type
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using greedy decoder")
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
self.decoder_type = "greedy"
self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
self.decoder_type = "beam"
self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
self.inference.kv_cache = self.kv_cache
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
# Tokens to carry over to next chunk for incomplete UTF-8 characters
self.pending_incomplete_tokens = []
def remove_hooks(self):
for hook in self.l_hooks:
hook.remove()
logger.info("Using beam decoder")
self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length)
self.state.inference.kv_cache = self.state.kv_cache
self.state.token_decoder = BeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size
)
def warmup(self, audio):
try:
@@ -204,96 +200,100 @@ class AlignAtt:
num_languages=self.model.num_languages,
task=self.decode_options.task
)
self.state.tokenizer = self.tokenizer
def init_context(self):
kw = {'tokenizer': self.tokenizer,
'device': self.model.device,
'prefix_token_ids': [self.tokenizer.sot_prev]}
self.context = TokenBuffer.empty(**kw)
self.state.context = TokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt
self.state.context.text += self.cfg.init_prompt
def init_tokens(self):
logger.debug(f"init tokens, {len(self.segments)}")
logger.debug(f"init tokens, {len(self.state.segments)}")
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.state.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# self.segments = []
logger.debug(f"init tokens after, {len(self.segments)}")
self.tokens = [self.initial_tokens]
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def trim_context(self):
logger.info("Trimming context")
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
logger.info(f"Context text: {self.context.as_text()}")
# logger.debug(f"Context tensor: {self.context.as_tensor()}")
l = sum(t.shape[1] for t in self.tokens) + c
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.context.trim_words(after=after)
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
logger.info(f"Context after trim: {self.context.text} (len: {l})")
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
if self.cfg.decoder_type == "greedy":
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
def logits(
self,
tokens: torch.Tensor,
audio_features: torch.Tensor,
return_cross_attn: bool = False
):
"""Get logits from decoder, optionally returning cross-attention weights."""
if self.state.decoder_type == "greedy":
return self.model.decoder(
tokens, audio_features,
kv_cache=self.state.kv_cache,
return_cross_attn=return_cross_attn
)
else:
logger.debug(f"Logits shape: {tokens.shape}")
logit = self.inference.logits(tokens, audio_features)
return logit
return self.state.inference.logits(
tokens, audio_features,
return_cross_attn=return_cross_attn
)
def refresh_segment(self, complete=False):
logger.debug("Refreshing segment:")
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
# self.detected_language = None
self.cumulative_time_offset = 0.0
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2:
self.segments = self.segments[-2:]
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.segments = []
self.log_segments += 1
self.pending_incomplete_tokens = []
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
if self.always_fire: return True
if self.never_fire: return False
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
def _current_tokens(self):
toks = self.tokens
toks = self.state.tokens
# very first infer: duplicate start of seq to beam_size
if toks[0].shape[0] == 1:
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
if not self.context.is_empty():
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
if not self.state.context.is_empty():
context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
toks = [context_toks] + toks
# make it one tensor
@@ -313,7 +313,7 @@ class AlignAtt:
### audio buffer
def segments_len(self):
segments_len = sum(s.shape[0] for s in self.segments) / 16000
segments_len = sum(s.shape[0] for s in self.state.segments) / 16000
return segments_len
def _apply_minseglen(self):
@@ -326,42 +326,36 @@ class AlignAtt:
def insert_audio(self, segment=None):
if segment is not None:
self.segments.append(segment)
self.state.segments.append(segment)
removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len()
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.cumulative_time_offset += removed_len # Track cumulative time removed
self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
if len(self.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:].tolist())
self.tokens = [self.initial_tokens] + self.tokens[2:]
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len # Track cumulative time removed
self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.state.tokens) > 1:
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _clean_cache(self):
'''clean the cache that stores the attention matrices and kv_cache.
It must be called every time after generation with the model.'''
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
"""Clean the kv_cache after each inference step."""
self.state.clean_cache()
@torch.no_grad()
def lang_id(self, encoder_features):
"""Language detection from encoder features.
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
This code is trimmed and copy-pasted from whisper.decoding.detect_language.
"""
# forward pass using a single token, startoftranscript
n_audio = encoder_features.shape[0]
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
# Note: don't use kv_cache for language detection
logits = self.model.logits(x, encoder_features)[:, 0]
# collect detected languages; suppress all non-language tokens
@@ -391,19 +385,19 @@ class AlignAtt:
@torch.no_grad()
def infer(self, is_last=False):
new_segment = True
if len(self.segments) == 0:
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
input_segments = torch.cat(self.state.segments, dim=0)
return []
# input_segments is concatenation of audio, it's one array
if len(self.segments) > 1:
input_segments = torch.cat(self.segments, dim=0)
if len(self.state.segments) > 1:
input_segments = torch.cat(self.state.segments, dim=0)
else:
input_segments = self.segments[0]
input_segments = self.state.segments[0]
beg_encode = time()
if self.use_mlcore:
@@ -457,18 +451,18 @@ class AlignAtt:
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
seconds_since_start = self.segments_len() - self.first_timestamp
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.detected_language = top_lan
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context()
@@ -488,92 +482,90 @@ class AlignAtt:
l_absolute_timestamps = []
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding] # Discard all new tokens
break
if new_segment:
tokens_for_logits = current_tokens
else:
# only need to use the last token except in the first forward pass
tokens_for_logits = current_tokens[:,-1:]
tokens_for_logits = current_tokens[:, -1:]
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
# Get logits and cross-attention weights from decoder
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result
# Accumulate cross-attention from this forward pass
accumulated_cross_attns.append(cross_attns)
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # logits for the last token
logits = logits[:, -1, :] # logits for the last token
# supress blank tokens only at the beginning of the segment
# suppress blank tokens only at the beginning of the segment
if new_segment:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
new_segment = False
self.suppress_tokens(logits)
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
self.state.suppress_tokens_fn(logits)
current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens)
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
for i, attn_mat in enumerate(self.dec_attns):
layer_rank = int(i % len(self.model.decoder.blocks))
align_heads_in_layer = self.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0)
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
t = torch.cat(mat, dim=1)
tmp.append(t)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
# Process accumulated cross-attention weights for alignment
attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
# for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1)
# Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in most_attended_frames.tolist()
]
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item()
l_absolute_timestamps.append(absolute_timestamps[0])
logger.debug("current tokens" + str(current_tokens.shape))
if completed:
# # stripping the last token, the eot
# stripping the last token, the eot
current_tokens = current_tokens[:, :-1]
break
# for some rare cases where the attention fails
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
# TODO: check this
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
logger.debug("ommit rewinding from special tokens")
self.last_attend_frame = most_attended_frame
logger.debug("omit rewinding from special tokens")
self.state.last_attend_frame = most_attended_frame
else:
logger.debug(
f"[rewind detected] current attention pos: {most_attended_frame}, "
f"last attention pos: {self.last_attend_frame}; omit this segment")
self.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
f"last attention pos: {self.state.last_attend_frame}; omit this segment")
self.state.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
break
else:
self.last_attend_frame = most_attended_frame
self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
@@ -593,12 +585,12 @@ class AlignAtt:
tokens_to_split = current_tokens[0, token_len_before_decoding:]
# Prepend pending tokens from previous chunk if any
if self.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
if self.state.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}")
pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device)
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
if fire_detected or is_last: #or punctuation_stop:
if fire_detected or is_last:
new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
@@ -609,20 +601,18 @@ class AlignAtt:
else:
new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.device,
)
self.tokens.append(new_tokens)
self.state.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
self.first_timestamp = l_absolute_timestamps[0]
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
@@ -641,20 +631,89 @@ class AlignAtt:
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text= word,
speaker=self.speaker,
detected_language=self.detected_language
).with_offset(
self.global_time_offset
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text=word,
speaker=self.state.speaker,
detected_language=self.state.detected_language
).with_offset(
self.state.global_time_offset
)
timestamped_words.append(timestamp_entry)
# Hold incomplete tokens for next chunk
self.pending_incomplete_tokens = []
# Hold incomplete tokens for next chunk (with limit to prevent hallucination accumulation)
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10 # Real incomplete UTF-8 chars are at most a few tokens
if split_words and replacement_char in split_words[-1]:
self.pending_incomplete_tokens = split_tokens[-1]
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk")
else:
logger.warning(f"[UTF-8 Fix] Skipping {len(split_tokens[-1])} tokens (exceeds limit of {MAX_PENDING_TOKENS}, likely hallucination)")
return timestamped_words
def _process_cross_attention(
self,
cross_attns: List[torch.Tensor],
content_mel_len: int
) -> torch.Tensor:
"""
Process cross-attention weights from decoder layers for alignment.
Args:
cross_attns: List of cross-attention tensors from each decoder layer.
Each tensor has shape (batch, n_head, seq_len, audio_len)
content_mel_len: Length of actual audio content in mel frames
Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len)
"""
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = len(self.model.decoder.blocks)
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
layer_rank = idx % num_decoder_layers
# attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
attn_mat = F.softmax(attn_mat, dim=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
# (n_head, seq_len, audio_len) when squeezed
if attn_mat.dim() == 4:
a = attn_mat[0, head_id, :, :] # (seq_len, audio_len)
else:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0) # (1, seq_len, audio_len)
else:
# attn_mat: (batch, n_head, seq_len, audio_len)
a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len)
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len)
tmp.append(t)
if not tmp:
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
# stck al heads: (batch, num_align_heads, seq_len, audio_len)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
return attn_of_alignment_heads

View File

@@ -1,5 +1,8 @@
import torch
import sys
import torch
class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):

View File

@@ -1,261 +0,0 @@
"""
ALPHA. results are not great yet
To replace `whisperlivekit.silero_vad_iterator import FixedVADIterator`
by `from whisperlivekit.ten_vad_iterator import TenVADIterator`
Use self.vac = TenVADIterator() instead of self.vac = FixedVADIterator(models.vac_model)
"""
import numpy as np
from ten_vad import TenVad
class TenVADIterator:
def __init__(self,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30):
self.vad = TenVad()
self.threshold = threshold
self.sampling_rate = sampling_rate
self.min_silence_duration_ms = min_silence_duration_ms
self.speech_pad_ms = speech_pad_ms
self.min_silence_samples = int(sampling_rate * min_silence_duration_ms / 1000)
self.speech_pad_samples = int(sampling_rate * speech_pad_ms / 1000)
self.reset_states()
def reset_states(self):
self.triggered = False
self.temp_end = 0
self.current_sample = 0
self.buffer = np.array([], dtype=np.float32)
def __call__(self, x, return_seconds=False):
if not isinstance(x, np.ndarray):
x = np.array(x, dtype=np.float32)
self.buffer = np.append(self.buffer, x)
chunk_size = 256
ret = None
while len(self.buffer) >= chunk_size:
chunk = self.buffer[:chunk_size].astype(np.int16)
self.buffer = self.buffer[chunk_size:]
window_size_samples = len(chunk)
self.current_sample += window_size_samples
speech_prob, speech_flag = self.vad.process(chunk)
if (speech_prob >= self.threshold) and self.temp_end:
self.temp_end = 0
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
result = {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
if ret is None:
ret = result
elif "end" in ret:
ret = result
else:
ret.update(result)
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
self.temp_end = self.current_sample
if self.current_sample - self.temp_end < self.min_silence_samples:
continue
else:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0
self.triggered = False
result = {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
if ret is None:
ret = result
else:
ret.update(result)
return ret if ret != {} else None
def test_on_record_wav():
import os
from pathlib import Path
audio_file = Path("record.wav")
if not audio_file.exists():
return
import soundfile as sf
audio_data, sample_rate = sf.read(str(audio_file), dtype='float32')
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
vad = TenVADIterator(
threshold=0.5,
sampling_rate=sample_rate,
min_silence_duration_ms=100,
speech_pad_ms=30
)
chunk_size = 1024
speech_segments = []
current_segment = None
for i in range(0, len(audio_data), chunk_size):
chunk = audio_data[i:i+chunk_size]
if chunk.dtype != np.int16:
chunk_int16 = (chunk * 32767.0).astype(np.int16)
else:
chunk_int16 = chunk
result = vad(chunk_int16, return_seconds=True)
if result is not None:
if 'start' in result:
current_segment = {'start': result['start'], 'end': None}
print(f"Speech start detected at {result['start']:.2f}s")
elif 'end' in result:
if current_segment:
current_segment['end'] = result['end']
duration = current_segment['end'] - current_segment['start']
speech_segments.append(current_segment)
print(f"Speech end detected at {result['end']:.2f}s (duration: {duration:.2f}s)")
current_segment = None
else:
print(f"Speech end detected at {result['end']:.2f}s (no corresponding start)")
if current_segment and current_segment['end'] is None:
current_segment['end'] = len(audio_data) / sample_rate
speech_segments.append(current_segment)
print(f"End of file (last segment at {current_segment['start']:.2f}s)")
print("-" * 60)
print(f"\nSummary:")
print(f"Number of speech segments detected: {len(speech_segments)}")
if speech_segments:
total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
total_time = len(audio_data) / sample_rate
speech_ratio = (total_speech_time / total_time) * 100
print(f"Total speech time: {total_speech_time:.2f}s")
print(f"Total file time: {total_time:.2f}s")
print(f"Speech ratio: {speech_ratio:.1f}%")
print(f"\nDetected segments:")
for i, seg in enumerate(speech_segments, 1):
print(f" {i}. {seg['start']:.2f}s - {seg['end']:.2f}s (duration: {seg['end'] - seg['start']:.2f}s)")
else:
print("No speech segments detected")
print("\n" + "=" * 60)
print("Extracting silence segments...")
silence_segments = []
total_time = len(audio_data) / sample_rate
if not speech_segments:
silence_segments = [{'start': 0.0, 'end': total_time}]
else:
if speech_segments[0]['start'] > 0:
silence_segments.append({'start': 0.0, 'end': speech_segments[0]['start']})
for i in range(len(speech_segments) - 1):
silence_start = speech_segments[i]['end']
silence_end = speech_segments[i + 1]['start']
if silence_end > silence_start:
silence_segments.append({'start': silence_start, 'end': silence_end})
if speech_segments[-1]['end'] < total_time:
silence_segments.append({'start': speech_segments[-1]['end'], 'end': total_time})
silence_audio = np.array([], dtype=audio_data.dtype)
for seg in silence_segments:
start_sample = int(seg['start'] * sample_rate)
end_sample = int(seg['end'] * sample_rate)
start_sample = max(0, min(start_sample, len(audio_data)))
end_sample = max(0, min(end_sample, len(audio_data)))
if end_sample > start_sample:
silence_audio = np.concatenate([silence_audio, audio_data[start_sample:end_sample]])
if len(silence_audio) > 0:
output_file = "record_silence_only.wav"
try:
import soundfile as sf
sf.write(output_file, silence_audio, sample_rate)
print(f"Silence file saved: {output_file}")
except ImportError:
try:
from scipy.io import wavfile
if silence_audio.dtype == np.float32:
silence_audio_int16 = (silence_audio * 32767.0).astype(np.int16)
else:
silence_audio_int16 = silence_audio.astype(np.int16)
wavfile.write(output_file, sample_rate, silence_audio_int16)
print(f"Silence file saved: {output_file}")
except ImportError:
print("Unable to save: soundfile or scipy required")
total_silence_time = sum(seg['end'] - seg['start'] for seg in silence_segments)
silence_ratio = (total_silence_time / total_time) * 100
print(f"Total silence duration: {total_silence_time:.2f}s")
print(f"Silence ratio: {silence_ratio:.1f}%")
print(f"Number of silence segments: {len(silence_segments)}")
print(f"\nYou can listen to {output_file} to verify that only silences are present.")
else:
print("No silence segments found (file entirely speech)")
print("\n" + "=" * 60)
print("Extracting speech segments...")
if speech_segments:
speech_audio = np.array([], dtype=audio_data.dtype)
for seg in speech_segments:
start_sample = int(seg['start'] * sample_rate)
end_sample = int(seg['end'] * sample_rate)
start_sample = max(0, min(start_sample, len(audio_data)))
end_sample = max(0, min(end_sample, len(audio_data)))
if end_sample > start_sample:
speech_audio = np.concatenate([speech_audio, audio_data[start_sample:end_sample]])
if len(speech_audio) > 0:
output_file = "record_speech_only.wav"
try:
import soundfile as sf
sf.write(output_file, speech_audio, sample_rate)
print(f"Speech file saved: {output_file}")
except ImportError:
try:
from scipy.io import wavfile
if speech_audio.dtype == np.float32:
speech_audio_int16 = (speech_audio * 32767.0).astype(np.int16)
else:
speech_audio_int16 = speech_audio.astype(np.int16)
wavfile.write(output_file, sample_rate, speech_audio_int16)
print(f"Speech file saved: {output_file}")
except ImportError:
print("Unable to save: soundfile or scipy required")
total_speech_time = sum(seg['end'] - seg['start'] for seg in speech_segments)
speech_ratio = (total_speech_time / total_time) * 100
print(f"Total speech duration: {total_speech_time:.2f}s")
print(f"Speech ratio: {speech_ratio:.1f}%")
print(f"Number of speech segments: {len(speech_segments)}")
print(f"\nYou can listen to {output_file} to verify that only speech segments are present.")
else:
print("No speech audio to extract")
else:
print("No speech segments found (file entirely silence)")
if __name__ == "__main__":
test_on_record_wav()

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from typing import Optional, List, Union, Dict, Any
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@@ -107,6 +107,21 @@ class Silence():
return True
@dataclass
class SegmentBuffer:
"""Per-segment buffer for ephemeral/unvalidated content."""
transcription: str = ''
diarization: str = ''
translation: str = ''
def to_dict(self) -> Dict[str, str]:
return {
'transcription': self.transcription,
'diarization': self.diarization,
'translation': self.translation
}
@dataclass
class Segment(TimedText):
"""Generic contiguous span built from tokens or silence markers."""
@@ -114,11 +129,18 @@ class Segment(TimedText):
end: Optional[float]
text: Optional[str]
speaker: Optional[str]
id: Optional[int] = None
start_speaker: Optional[float] = None
tokens: Optional[ASRToken] = None
translation: Optional[Translation] = None
buffer: Optional[SegmentBuffer] = None
@classmethod
def from_tokens(
cls,
tokens: List[Union[ASRToken, Silence]],
is_silence: bool = False
is_silence: bool = False,
segment_id: Optional[int] = None
) -> Optional["Segment"]:
"""Return a normalized segment representing the provided tokens."""
if not tokens:
@@ -131,7 +153,9 @@ class Segment(TimedText):
start=start_token.start,
end=end_token.end,
text=None,
speaker=-2
speaker=-2,
id=segment_id,
start_speaker=start_token.start
)
else:
return cls(
@@ -139,53 +163,36 @@ class Segment(TimedText):
end=end_token.end,
text=''.join(token.text for token in tokens),
speaker=-1,
id=segment_id,
start_speaker=start_token.start,
detected_language=start_token.detected_language
)
def is_silence(self) -> bool:
"""True when this segment represents a silence gap."""
return self.speaker == -2
@dataclass
class Line(TimedText):
translation: str = ''
def to_dict(self) -> Dict[str, Any]:
"""Serialize the line for frontend consumption."""
"""Serialize the segment for frontend consumption (new API format)."""
_dict: Dict[str, Any] = {
'id': self.id if self.id is not None else 0,
'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text,
'text': self.text or '',
'start_speaker': format_time(self.start_speaker) if self.start_speaker is not None else format_time(self.start),
'start': format_time(self.start),
'end': format_time(self.end),
'language': self.detected_language,
'translation': self.translation or '',
'buffer': self.buffer.to_dict() if self.buffer else SegmentBuffer().to_dict()
}
if self.translation:
_dict['translation'] = self.translation
if self.detected_language:
_dict['detected_language'] = self.detected_language
return _dict
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
"""Populate line attributes from a contiguous token list."""
self.text = ''.join([token.text for token in tokens])
self.start = tokens[0].start
self.end = tokens[-1].end
self.speaker = 1
self.detected_language = tokens[0].detected_language
return self
def build_from_segment(self, segment: Segment) -> "Line":
"""Populate the line fields from a pre-built segment."""
self.text = segment.text
self.start = segment.start
self.end = segment.end
self.speaker = segment.speaker
self.detected_language = segment.detected_language
return self
def is_silent(self) -> bool:
return self.speaker == -2
@dataclass
class PuncSegment(Segment):
pass
class SilentLine(Line):
class SilentSegment(Segment):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.speaker = -2
@@ -196,23 +203,20 @@ class SilentLine(Line):
class FrontData():
status: str = ''
error: str = ''
lines: list[Line] = field(default_factory=list)
buffer_transcription: str = ''
buffer_diarization: str = ''
buffer_translation: str = ''
segments: list[Segment] = field(default_factory=list)
remaining_time_transcription: float = 0.
remaining_time_diarization: float = 0.
def to_dict(self) -> Dict[str, Any]:
"""Serialize the front-end data payload."""
"""Serialize the front-end data payload (new API format)."""
_dict: Dict[str, Any] = {
'type': 'transcript_update',
'status': self.status,
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
'buffer_transcription': self.buffer_transcription,
'buffer_diarization': self.buffer_diarization,
'buffer_translation': self.buffer_translation,
'remaining_time_transcription': self.remaining_time_transcription,
'remaining_time_diarization': self.remaining_time_diarization,
'segments': [seg.to_dict() for seg in self.segments if (seg.text or seg.speaker == -2)],
'metadata': {
'remaining_time_transcription': self.remaining_time_transcription,
'remaining_time_diarization': self.remaining_time_diarization,
}
}
if self.error:
_dict['error'] = self.error

View File

@@ -1,10 +1,14 @@
from time import time
from typing import Optional, List, Tuple, Union, Any
from typing import Any, List, Optional, Tuple, Union
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
from whisperlivekit.timed_objects import (ASRToken, Segment, SegmentBuffer, PuncSegment, Silence,
SilentSegment, SpeakerSegment,
TimedText)
class TokensAlignment:
# Minimum duration (seconds) for a silence to be displayed
MIN_SILENCE_DISPLAY_DURATION = 2.0
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
self.state = state
@@ -25,6 +29,22 @@ class TokensAlignment:
self.sep: str = sep if sep is not None else ' '
self.beg_loop: Optional[float] = None
self.validated_segments: List[Segment] = []
self.current_line_tokens: List[ASRToken] = []
self.diarization_buffer: List[ASRToken] = []
self.last_punctuation = None
self.last_uncompleted_punc_segment: PuncSegment = None
self.tokens_after_last_punctuation: PuncSegment = []
self.all_validated_segments: List[Segment] = []
# For token-by-token validation with diarization
self.pending_tokens: List[ASRToken] = []
self.last_validated_token_end: float = 0.0
# Segment ID counter for the new API
self._next_segment_id: int = 1
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
@@ -37,27 +57,27 @@ class TokensAlignment:
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def add_translation(self, line: Line) -> None:
"""Append translated text segments that overlap with a line."""
def add_translation(self, segment: Segment) -> None:
"""Append translated text segments that overlap with a segment."""
for ts in self.all_translation_segments:
if ts.is_within(line):
line.translation += ts.text + (self.sep if ts.text else '')
elif line.translation:
if ts.is_within(segment):
segment.translation += ts.text + (self.sep if ts.text else '')
elif segment.translation:
break
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
"""Group tokens into segments split by punctuation and explicit silence."""
segments = []
segment_start_idx = 0
for i, token in enumerate(self.all_tokens):
if token.is_silence():
previous_segment = Segment.from_tokens(
previous_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i],
)
if previous_segment:
segments.append(previous_segment)
segment = Segment.from_tokens(
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
@@ -65,19 +85,47 @@ class TokensAlignment:
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = Segment.from_tokens(
segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i+1],
)
segments.append(segment)
segment_start_idx = i+1
final_segment = Segment.from_tokens(
final_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx:],
)
if final_segment:
segments.append(final_segment)
return segments
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
new_punc_segments = []
segment_start_idx = 0
self.tokens_after_last_punctuation += self.new_tokens
for i, token in enumerate(self.tokens_after_last_punctuation):
if token.is_silence():
previous_segment = PuncSegment.from_tokens(
tokens=self.tokens_after_last_punctuation[segment_start_idx: i],
)
if previous_segment:
new_punc_segments.append(previous_segment)
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
new_punc_segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = PuncSegment.from_tokens(
tokens=self.tokens_after_last_punctuation[segment_start_idx: i+1],
)
new_punc_segments.append(segment)
segment_start_idx = i+1
self.tokens_after_last_punctuation = self.tokens_after_last_punctuation[segment_start_idx:]
return new_punc_segments
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
"""Merge consecutive diarization slices that share the same speaker."""
@@ -100,78 +148,227 @@ class TokensAlignment:
return max(0, end - start)
def get_lines_diarization(self) -> Tuple[List[Line], str]:
"""Build lines when diarization is enabled and track overflow buffer."""
diarization_buffer = ''
punctuation_segments = self.compute_punctuations_segments()
diarization_segments = self.concatenate_diar_segments()
for punctuation_segment in punctuation_segments:
if not punctuation_segment.is_silence():
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
diarization_buffer += punctuation_segment.text
else:
max_overlap = 0.0
max_overlap_speaker = 1
for diarization_segment in diarization_segments:
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
if intersec > max_overlap:
max_overlap = intersec
max_overlap_speaker = diarization_segment.speaker + 1
punctuation_segment.speaker = max_overlap_speaker
def _get_speaker_for_token(self, token: ASRToken, diarization_segments: List[SpeakerSegment]) -> Optional[int]:
"""Get speaker ID for a token based on diarization overlap. Returns None if not covered."""
if not diarization_segments:
return None
lines = []
if punctuation_segments:
lines = [Line().build_from_segment(punctuation_segments[0])]
for segment in punctuation_segments[1:]:
if segment.speaker == lines[-1].speaker:
if lines[-1].text:
lines[-1].text += segment.text
lines[-1].end = segment.end
# Check if token is beyond diarization coverage
if token.start >= diarization_segments[-1].end:
return None
# Find speaker with max overlap
max_overlap = 0.0
best_speaker = None
for diar_seg in diarization_segments:
overlap = self.intersection_duration(token, diar_seg)
if overlap > max_overlap:
max_overlap = overlap
best_speaker = diar_seg.speaker + 1 # 1-indexed
return best_speaker if max_overlap > 0 else None
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
"""Build segments with token-by-token validation when diarization covers them."""
diarization_segments = self.concatenate_diar_segments()
# Add new tokens to pending
self.pending_tokens.extend(self.new_tokens)
# Process pending tokens - validate those covered by diarization
still_pending = []
for token in self.pending_tokens:
if token.is_silence():
# Handle silence tokens
silence_duration = (token.end or 0) - (token.start or 0)
if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
# Significant silence - add as separate segment
if self.all_validated_segments and not self.all_validated_segments[-1].is_silence():
self.all_validated_segments.append(SilentSegment(
start=token.start,
end=token.end
))
elif self.all_validated_segments and self.all_validated_segments[-1].is_silence():
# Extend existing silence
self.all_validated_segments[-1].end = token.end
else:
self.all_validated_segments.append(SilentSegment(
start=token.start,
end=token.end
))
# Short silences are ignored (don't go to pending either)
continue
speaker = self._get_speaker_for_token(token, diarization_segments)
if speaker is not None:
# Token is covered by diarization - validate it
if self.all_validated_segments:
last_seg = self.all_validated_segments[-1]
if not last_seg.is_silence() and last_seg.speaker == speaker:
# Same speaker - append to existing segment
last_seg.text += token.text
last_seg.end = token.end
else:
# Different speaker or after silence - new segment
new_seg = Segment(
start=token.start,
end=token.end,
text=token.text,
speaker=speaker,
start_speaker=token.start,
detected_language=token.detected_language
)
self.all_validated_segments.append(new_seg)
else:
lines.append(Line().build_from_segment(segment))
# First segment
new_seg = Segment(
start=token.start,
end=token.end,
text=token.text,
speaker=speaker,
start_speaker=token.start,
detected_language=token.detected_language
)
self.all_validated_segments.append(new_seg)
self.last_validated_token_end = token.end
else:
# Token not yet covered by diarization - keep pending
still_pending.append(token)
self.pending_tokens = still_pending
# Build diarization buffer from pending tokens
diarization_buffer = ''.join(t.text for t in self.pending_tokens if not t.is_silence())
return self.all_validated_segments, diarization_buffer
return lines, diarization_buffer
def _assign_segment_ids(self, segments: List[Segment]) -> None:
"""Assign unique IDs to segments that don't have one yet."""
for segment in segments:
if segment.id is None:
segment.id = self._next_segment_id
self._next_segment_id += 1
def _assign_buffers_to_last_segment(
self,
segments: List[Segment],
buffer_transcription: str,
buffer_diarization: str,
buffer_translation: str
) -> None:
"""Assign buffer content to the last non-silent segment."""
# First, clear ALL buffers (they're ephemeral and shouldn't persist)
for segment in segments:
segment.buffer = SegmentBuffer()
# Find the last non-silent segment and assign buffers to it
for segment in reversed(segments):
if not segment.is_silence():
segment.buffer = SegmentBuffer(
transcription=buffer_transcription,
diarization=buffer_diarization,
translation=buffer_translation
)
break
def _filter_and_merge_segments(self, segments: List[Segment]) -> List[Segment]:
"""Filter parasitic silences and merge consecutive same-speaker segments."""
if not segments:
return segments
result = []
for seg in segments:
if seg.is_silence():
# Filter short silences
duration = (seg.end or 0) - (seg.start or 0)
if duration < self.MIN_SILENCE_DISPLAY_DURATION:
continue
# Merge consecutive silences
if result and result[-1].is_silence():
result[-1].end = seg.end
continue
else:
# Merge same speaker segments (across filtered silences)
if result and not result[-1].is_silence() and result[-1].speaker == seg.speaker:
result[-1].text += seg.text
result[-1].end = seg.end
continue
result.append(seg)
return result
def get_lines(
self,
diarization: bool = False,
translation: bool = False,
current_silence: Optional[Silence] = None
) -> Tuple[List[Line], str, Union[str, TimedText]]:
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
current_silence: Optional[Silence] = None,
buffer_transcription: str = ''
) -> List[Segment]:
"""Return the formatted segments with per-segment buffers, optionally with diarization/translation."""
diarization_buffer = ''
if diarization:
lines, diarization_buffer = self.get_lines_diarization()
segments, diarization_buffer = self.get_lines_diarization()
else:
diarization_buffer = ''
lines = []
current_line_tokens = []
for token in self.all_tokens:
for token in self.new_tokens:
if token.is_silence():
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = token.start,
end = end_silence
))
# Check silence duration before adding
silence_duration = (token.end or 0) - (token.start or 0)
if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
if self.current_line_tokens:
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence
else:
self.validated_segments.append(SilentSegment(
start=token.start,
end=end_silence
))
else:
current_line_tokens.append(token)
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
self.current_line_tokens.append(token)
segments = list(self.validated_segments)
if self.current_line_tokens:
segments.append(Segment().from_tokens(self.current_line_tokens))
# Handle current ongoing silence
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = current_silence.start,
end = end_silence
))
silence_duration = (current_silence.end or time() - self.beg_loop) - (current_silence.start or 0)
if silence_duration >= self.MIN_SILENCE_DISPLAY_DURATION:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if segments and segments[-1].is_silence():
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
else:
segments.append(SilentSegment(
start=current_silence.start,
end=end_silence
))
if translation:
[self.add_translation(line) for line in lines if not type(line) == Silence]
return lines, diarization_buffer, self.new_translation_buffer.text
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
# Get translation buffer text
translation_buffer = self.new_translation_buffer.text if self.new_translation_buffer else ''
# Filter parasitic silences and merge same-speaker segments
segments = self._filter_and_merge_segments(segments)
# Assign unique IDs to all segments
self._assign_segment_ids(segments)
# Assign buffers to the last active segment
self._assign_buffers_to_last_segment(
segments,
buffer_transcription=buffer_transcription,
buffer_diarization=diarization_buffer,
buffer_translation=translation_buffer
)
return segments

View File

@@ -7,6 +7,7 @@ def load_file(warmup_file=None, timeout=5):
import os
import tempfile
import urllib.request
import librosa
if warmup_file == "":

View File

@@ -454,8 +454,9 @@ label {
gap: 4px;
}
.lag-diarization-value {
margin-left: 10px;
.lag-diarization-value,
.lag-transcription-value {
font-weight: 600;
}
.label_translation img {

View File

@@ -232,11 +232,8 @@ function setupWebSocket() {
if (waitingForStop) {
statusText.textContent = "Processing finalized or connection closed.";
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
renderSegments(
lastReceivedData.segments || [],
0,
0,
true
@@ -278,11 +275,8 @@ function setupWebSocket() {
waitingForStop = false;
if (lastReceivedData) {
renderLinesWithBuffer(
lastReceivedData.lines || [],
lastReceivedData.buffer_diarization || "",
lastReceivedData.buffer_transcription || "",
lastReceivedData.buffer_translation || "",
renderSegments(
lastReceivedData.segments || [],
0,
0,
true
@@ -299,21 +293,20 @@ function setupWebSocket() {
lastReceivedData = data;
// New API format: segments with per-segment buffers, metadata wrapper
const {
lines = [],
buffer_transcription = "",
buffer_diarization = "",
buffer_translation = "",
remaining_time_transcription = 0,
remaining_time_diarization = 0,
segments = [],
metadata = {},
status = "active_transcription",
} = data;
const {
remaining_time_transcription = 0,
remaining_time_diarization = 0,
} = metadata;
renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
buffer_translation,
renderSegments(
segments,
remaining_time_diarization,
remaining_time_transcription,
false,
@@ -323,11 +316,8 @@ function setupWebSocket() {
});
}
function renderLinesWithBuffer(
lines,
buffer_diarization,
buffer_transcription,
buffer_translation,
function renderSegments(
segments,
remaining_time_diarization,
remaining_time_transcription,
isFinalizing = false,
@@ -339,33 +329,38 @@ function renderLinesWithBuffer(
return;
}
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
// Build signature for change detection
const signature = JSON.stringify({
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
buffer_transcription: buffer_transcription || "",
buffer_diarization: buffer_diarization || "",
buffer_translation: buffer_translation,
segments: (segments || []).map((it) => ({
id: it.id,
speaker: it.speaker,
text: it.text,
start: it.start,
end: it.end,
language: it.language,
buffer: it.buffer || {}
})),
status: current_status,
showLoading,
showTransLag,
showDiaLag,
isFinalizing: !!isFinalizing,
});
// Only update lag values if signature unchanged
if (lastSignature === signature) {
const t = document.querySelector(".lag-transcription-value");
if (t) t.textContent = fmt1(remaining_time_transcription);
const d = document.querySelector(".lag-diarization-value");
if (d) d.textContent = fmt1(remaining_time_diarization);
const ld = document.querySelector(".loading-diarization-value");
if (ld) ld.textContent = fmt1(remaining_time_diarization);
return;
}
lastSignature = signature;
const linesHtml = (lines || [])
const segmentsHtml = (segments || [])
.map((item, idx) => {
const buffer = item.buffer || {};
const buffer_transcription = buffer.transcription || "";
const buffer_diarization = buffer.diarization || "";
const buffer_translation = buffer.translation || "";
let timeInfo = "";
if (item.start !== undefined && item.end !== undefined) {
timeInfo = ` ${item.start} - ${item.end}`;
@@ -373,80 +368,78 @@ function renderLinesWithBuffer(
let speakerLabel = "";
if (item.speaker === -2) {
// Silence segment
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) {
// Normal speaker segment
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
if (item.detected_language) {
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
if (item.language) {
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.language}</span></span>`;
}
}
let currentLineText = item.text || "";
if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) {
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
remaining_time_transcription
)}</span>s</span></span>`;
if (buffer_diarization && remaining_time_diarization) {
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
remaining_time_diarization
)}</span>s</span></span>`;
}
const isLastSegment = idx === segments.length - 1;
const hasBufferContent = buffer_diarization || buffer_transcription;
// Show lag indicators on last non-silent segment (without spinners)
if (isLastSegment && item.speaker !== -2 && !isFinalizing) {
if (remaining_time_transcription > 0) {
speakerLabel += `<span class="label_transcription">Transcription lag: <span class="lag-transcription-value">${fmt1(remaining_time_transcription)}</span>s</span>`;
}
if (buffer_diarization && remaining_time_diarization > 0) {
speakerLabel += `<span class="label_diarization">Diarization lag: <span class="lag-diarization-value">${fmt1(remaining_time_diarization)}</span>s</span>`;
}
}
// Render buffers
if (hasBufferContent && item.speaker !== -2) {
if (buffer_diarization) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_diarization.trim();
} else {
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
}
}
if (buffer_transcription) {
if (isFinalizing) {
currentLineText +=
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
buffer_transcription.trim();
currentLineText += (currentLineText.length > 0 ? " " : "") + buffer_transcription.trim();
} else {
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
}
}
}
// Translation
let translationContent = "";
if (item.translation) {
translationContent += item.translation.trim();
}
if (idx === lines.length - 1 && buffer_translation) {
if (buffer_translation) {
const bufferPiece = isFinalizing
? buffer_translation
: `<span class="buffer_translation">${buffer_translation}</span>`;
translationContent += translationContent ? `${bufferPiece}` : bufferPiece;
translationContent += translationContent ? bufferPiece : bufferPiece;
}
if (translationContent.trim().length > 0) {
currentLineText += `
<div>
<div class="label_translation">
${translationIcon}
<span class="translation_text">${translationContent}</span>
</div>
</div>`;
<div class="label_translation">
${translationIcon}
<span class="translation_text">${translationContent}</span>
</div>`;
}
return currentLineText.trim().length > 0 || speakerLabel.length > 0
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
: `<p>${speakerLabel}<br/></p>`;
if (currentLineText.trim().length > 0 || speakerLabel.length > 0) {
return `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`;
}
return speakerLabel ? `<p>${speakerLabel}</p>` : "";
})
.filter(html => html.length > 0)
.join("");
linesTranscriptDiv.innerHTML = linesHtml;
linesTranscriptDiv.innerHTML = segmentsHtml;
const transcriptContainer = document.querySelector('.transcript-container');
if (transcriptContainer) {
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });

View File

@@ -0,0 +1,377 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WhisperLiveKit Transcript</title>
<style>
:root {
--bg: #111;
--text: #ddd;
--dim: #666;
--border: #333;
--active: #e74c3c;
}
body {
font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
background: var(--bg);
color: var(--text);
margin: 0;
padding: 2rem;
font-size: 13px;
line-height: 1.6;
}
.nav {
display: flex;
gap: 12px;
align-items: center;
margin-bottom: 3rem;
font-size: 12px;
}
button, input, select {
background: transparent;
border: 1px solid var(--border);
color: var(--dim);
padding: 6px 12px;
font-family: inherit;
font-size: inherit;
border-radius: 4px;
outline: none;
transition: all 0.2s;
}
button:hover, input:hover, input:focus, select:hover, select:focus {
border-color: var(--text);
color: var(--text);
cursor: pointer;
}
select {
cursor: pointer;
appearance: none; /* Minimalist look */
background-image: linear-gradient(45deg, transparent 50%, var(--dim) 50%), linear-gradient(135deg, var(--dim) 50%, transparent 50%);
background-position: calc(100% - 15px) 50%, calc(100% - 10px) 50%;
background-size: 5px 5px, 5px 5px;
background-repeat: no-repeat;
padding-right: 25px;
}
select:hover, select:focus {
background-image: linear-gradient(45deg, transparent 50%, var(--text) 50%), linear-gradient(135deg, var(--text) 50%, transparent 50%);
}
button.recording {
border-color: var(--active);
color: var(--active);
}
input {
width: 150px;
cursor: text;
}
#status {
margin-left: auto;
color: var(--dim);
}
#transcript {
white-space: pre-wrap;
word-wrap: break-word;
max-width: 800px;
margin: 0 auto;
outline: none;
}
/* Minimal scrollbar */
::-webkit-scrollbar { width: 6px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb { background: #222; border-radius: 3px; }
::-webkit-scrollbar-thumb:hover { background: #333; }
</style>
</head>
<body>
<div class="nav">
<button id="recordBtn">Record</button>
<button id="copyBtn">Copy</button>
<select id="microphoneSelect"></select>
<input type="text" id="wsUrl" placeholder="WebSocket URL">
<div id="status">Ready</div>
</div>
<div id="transcript"></div>
<script>
const recordBtn = document.getElementById('recordBtn');
const copyBtn = document.getElementById('copyBtn');
const wsUrlInput = document.getElementById('wsUrl');
const statusEl = document.getElementById('status');
const transcriptEl = document.getElementById('transcript');
const microphoneSelect = document.getElementById('microphoneSelect');
// Default WebSocket URL
const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
const host = window.location.hostname || 'localhost';
const port = window.location.port;
const defaultUrl = `${protocol}://${host}${port ? ':' + port : ''}/asr`;
wsUrlInput.value = defaultUrl;
let websocket = null;
let isRecording = false;
let audioContext = null;
let workletNode = null;
let recorderWorker = null;
let microphone = null;
let useAudioWorklet = false;
let recorder = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
async function enumerateMicrophones() {
try {
// Request permission first to get labels
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
} catch (error) {
console.error('Error enumerating microphones:', error);
statusEl.textContent = "Mic permission needed";
}
}
function populateMicrophoneSelect() {
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
if (isRecording) {
stopRecording();
setTimeout(() => {
startRecording();
}, 500);
}
}
microphoneSelect.addEventListener('change', handleMicrophoneChange);
// Initial enumeration
enumerateMicrophones();
navigator.mediaDevices.addEventListener('devicechange', enumerateMicrophones);
function formatSegment(segment) {
const speaker = segment.speaker;
const text = segment.text || '';
const buffer = segment.buffer || {};
const start = segment.start || '';
const end = segment.end || '';
const language = segment.language || '';
let output = '';
// Silence marker
if (speaker === -2) {
output += `[SILENCE ${start} - ${end}]\n`;
return output;
}
// Speaker header
output += `[SPEAKER ${speaker}]`;
if (start && end) output += ` ${start} - ${end}`;
if (language) output += ` [LANG: ${language}]`;
output += '\n';
// Main text
if (text) {
output += text;
}
// Diarization buffer (text waiting for speaker assignment)
if (buffer.diarization) {
output += `[DIAR_BUFFER]${buffer.diarization}[/DIAR_BUFFER]`;
}
// Transcription buffer (text waiting for validation)
if (buffer.transcription) {
output += `[TRANS_BUFFER]${buffer.transcription}[/TRANS_BUFFER]`;
}
output += '\n';
// Translation
if (segment.translation) {
output += `[TRANSLATION]${segment.translation}`;
if (buffer.translation) {
output += `[TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER]`;
}
output += `[/TRANSLATION]\n`;
} else if (buffer.translation) {
output += `[TRANSLATION][TRANS_BUFFER]${buffer.translation}[/TRANS_BUFFER][/TRANSLATION]\n`;
}
return output;
}
function renderTranscript(data) {
const { segments = [], metadata = {}, status: msgStatus } = data;
if (msgStatus === 'no_audio_detected') {
// transcriptEl.textContent = '[NO AUDIO DETECTED]';
// Minimalist: maybe just don't show anything or show status
statusEl.textContent = 'No audio detected';
return;
}
let output = '';
// Metadata header
const remainingTrans = metadata.remaining_time_transcription || 0;
const remainingDiar = metadata.remaining_time_diarization || 0;
if (remainingTrans > 0 || remainingDiar > 0) {
output += `[LAG: trans=${remainingTrans.toFixed(1)}s diar=${remainingDiar.toFixed(1)}s]\n\n`;
}
// All segments
for (const segment of segments) {
output += formatSegment(segment);
output += '\n';
}
transcriptEl.textContent = output;
transcriptEl.scrollTop = transcriptEl.scrollHeight;
}
async function startRecording() {
try {
websocket = new WebSocket(wsUrlInput.value);
websocket.onopen = async () => {
statusEl.textContent = 'Connecting...';
};
websocket.onmessage = async (event) => {
const data = JSON.parse(event.data);
if (data.type === 'config') {
useAudioWorklet = !!data.useAudioWorklet;
statusEl.textContent = 'Recording';
await initAudio();
return;
}
if (data.type === 'ready_to_stop') {
statusEl.textContent = 'Done';
return;
}
// transcript_update
renderTranscript(data);
};
websocket.onclose = () => {
statusEl.textContent = 'Disconnected';
stopRecording(false);
};
websocket.onerror = () => {
statusEl.textContent = 'Error';
};
} catch (err) {
statusEl.textContent = 'Error: ' + err.message;
}
}
async function initAudio() {
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
audioContext = new (window.AudioContext || window.webkitAudioContext)();
microphone = audioContext.createMediaStreamSource(stream);
if (useAudioWorklet) {
await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');
workletNode = new AudioWorkletNode(audioContext, 'pcm-forwarder', {
numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1
});
microphone.connect(workletNode);
recorderWorker = new Worker('/web/recorder_worker.js');
recorderWorker.postMessage({ command: 'init', config: { sampleRate: audioContext.sampleRate } });
recorderWorker.onmessage = (e) => {
if (websocket?.readyState === WebSocket.OPEN) {
websocket.send(e.data.buffer);
}
};
workletNode.port.onmessage = (e) => {
const ab = e.data instanceof ArrayBuffer ? e.data : e.data.buffer;
recorderWorker.postMessage({ command: 'record', buffer: ab }, [ab]);
};
} else {
try {
recorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
} catch {
recorder = new MediaRecorder(stream);
}
recorder.ondataavailable = (e) => {
if (websocket?.readyState === WebSocket.OPEN && e.data?.size > 0) {
websocket.send(e.data);
}
};
recorder.start(100);
}
isRecording = true;
recordBtn.textContent = 'Stop';
recordBtn.classList.add('recording');
}
function stopRecording(sendStop = true) {
if (sendStop && websocket?.readyState === WebSocket.OPEN) {
websocket.send(new Blob([], { type: 'audio/webm' }));
}
if (recorder) { try { recorder.stop(); } catch {} recorder = null; }
if (recorderWorker) { recorderWorker.terminate(); recorderWorker = null; }
if (workletNode) { workletNode.disconnect(); workletNode = null; }
if (microphone) { microphone.disconnect(); microphone = null; }
if (audioContext) { audioContext.close(); audioContext = null; }
isRecording = false;
recordBtn.textContent = 'Record';
recordBtn.classList.remove('recording');
}
recordBtn.addEventListener('click', () => {
if (!isRecording) {
startRecording();
} else {
stopRecording();
}
});
copyBtn.addEventListener('click', () => {
navigator.clipboard.writeText(transcriptEl.textContent).then(() => {
const original = copyBtn.textContent;
copyBtn.textContent = 'Copied';
setTimeout(() => { copyBtn.textContent = original; }, 1500);
});
});
</script>
</body>
</html>

View File

@@ -1,6 +1,6 @@
import logging
import importlib.resources as resources
import base64
import importlib.resources as resources
import logging
logger = logging.getLogger(__name__)
@@ -13,6 +13,37 @@ def get_web_interface_html():
logger.error(f"Error loading web interface HTML: {e}")
return "<html><body><h1>Error loading interface</h1></body></html>"
def get_text_transcript_html():
"""Loads the simple text-based transcript HTML for easy copy/paste."""
try:
with resources.files('whisperlivekit.web').joinpath('text_transcript.html').open('r', encoding='utf-8') as f:
html_content = f.read()
# Inline the worker scripts
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
worklet_code = f.read()
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
worker_code = f.read()
html_content = html_content.replace(
"await audioContext.audioWorklet.addModule('/web/pcm_worklet.js');",
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
'await audioContext.audioWorklet.addModule(workletUrl);'
)
html_content = html_content.replace(
"recorderWorker = new Worker('/web/recorder_worker.js');",
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
'recorderWorker = new Worker(workerUrl);'
)
return html_content
except Exception as e:
logger.error(f"Error loading text transcript HTML: {e}")
return "<html><body><h1>Error loading text interface</h1></body></html>"
def get_inline_ui_html():
"""Returns the complete web interface HTML with all assets embedded in a single call."""
try:
@@ -96,11 +127,13 @@ def get_inline_ui_html():
if __name__ == '__main__':
import pathlib
import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import uvicorn
from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg
app = FastAPI()

View File

@@ -4,18 +4,22 @@ import json
import os
import urllib
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from pathlib import Path
from torch import Tensor
from tqdm import tqdm
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
pad_or_trim)
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
decode, detect_language)
from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__
from whisperlivekit.whisper.lora import (LoRAAdapter, LoRAAdapterManager,
LoRAConfig, LoRALinear)
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
@@ -262,9 +266,49 @@ def _collapse_hf_module_name(module: str):
return module
def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
"""
Resolve LoRA adapter path - handles both local paths and HuggingFace repo IDs.
If lora_path is a local directory containing adapter files, returns it as-is.
If lora_path looks like a HuggingFace repo ID (contains '/'), downloads and caches it.
"""
if not lora_path:
return None
# Check if it's already a valid local path
if os.path.isdir(lora_path):
config_path = os.path.join(lora_path, "adapter_config.json")
if os.path.isfile(config_path):
return lora_path
# Try to download from HuggingFace Hub
if "/" in lora_path:
try:
from huggingface_hub import snapshot_download
local_path = snapshot_download(
repo_id=lora_path,
allow_patterns=["adapter_config.json", "adapter_model.*"],
)
return local_path
except Exception as e:
raise FileNotFoundError(
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
)
raise FileNotFoundError(
f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
)
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
if not lora_path:
return
# Resolve path (handles HuggingFace Hub download)
lora_path = _resolve_lora_path(lora_path)
if not lora_path:
return
config_path = os.path.join(lora_path, "adapter_config.json")
if not os.path.isfile(config_path):
@@ -317,6 +361,75 @@ def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str])
)
def _load_checkpoint(
file_path: Union[str, Path],
device: str,
in_memory: bool = False,
checkpoint_bytes: Optional[bytes] = None,
) -> Dict[str, torch.Tensor]:
"""
Load a checkpoint from a single file.
Handles .pt, .bin, and .safetensors formats.
"""
if checkpoint_bytes is not None:
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
file_path = Path(file_path)
suffix = file_path.suffix.lower()
if suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load .safetensors model files: `pip install safetensors`"
)
return load_file(str(file_path), device=device)
else:
if in_memory:
with open(file_path, "rb") as f:
checkpoint_bytes = f.read()
with io.BytesIO(checkpoint_bytes) as fp:
return torch.load(fp, map_location=device)
else:
with open(file_path, "rb") as fp:
return torch.load(fp, map_location=device)
def _load_sharded_checkpoint(
shard_files: List[Path],
device: str,
) -> Dict[str, torch.Tensor]:
"""
Load a sharded checkpoint (multiple .safetensors or .bin files).
Merges all shards into a single state dict.
"""
merged_state_dict = {}
first_suffix = shard_files[0].suffix.lower()
if first_suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError(
"Please install safetensors to load sharded .safetensors model: `pip install safetensors`"
)
for shard_path in shard_files:
shard_dict = load_file(str(shard_path), device=device)
merged_state_dict.update(shard_dict)
else:
for shard_path in shard_files:
with open(shard_path, "rb") as fp:
shard_dict = torch.load(fp, map_location=device)
if isinstance(shard_dict, dict):
merged_state_dict.update(shard_dict)
return merged_state_dict
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
@@ -334,6 +447,8 @@ def load_model(
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
Can be a single file (.pt, .bin, .safetensors), a directory containing model files,
or a sharded model directory with files like model-00001-of-00002.safetensors.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
@@ -348,16 +463,51 @@ def load_model(
model : Whisper
The Whisper ASR model instance
"""
from whisperlivekit.model_paths import detect_model_format
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
checkpoint = None
model_path_for_config = name # Used to find config.json for dims inference
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
if in_memory:
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_file)
else:
checkpoint = _load_checkpoint(checkpoint_file, device)
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
if in_memory:
with open(name, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(name, device)
model_path_for_config = name
elif os.path.isdir(name):
model_info = detect_model_format(name)
if not model_info.has_pytorch:
raise RuntimeError(
f"No PyTorch checkpoint found in directory {name}. "
f"Expected .pt, .bin, or .safetensors file(s)."
)
if model_info.is_sharded:
checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
else:
single_file = model_info.pytorch_files[0]
if in_memory:
with open(single_file, "rb") as f:
checkpoint_bytes = f.read()
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
else:
checkpoint = _load_checkpoint(single_file, device)
model_path_for_config = name
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
@@ -367,22 +517,6 @@ def load_model(
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
if in_memory:
checkpoint = load_file(checkpoint_file, device=device)
else:
checkpoint = load_file(checkpoint_file, device=device)
else:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
@@ -394,7 +528,7 @@ def load_model(
if dims_cfg is not None:
dims = ModelDimensions(**dims_cfg)
else:
dims = _infer_dims_from_config(name)
dims = _infer_dims_from_config(model_path_for_config)
if dims is None:
raise RuntimeError(
"Could not determine model dimensions. "
@@ -419,6 +553,94 @@ def load_model(
return model.to(device)
def load_model_with_lora_manager(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only: bool = False,
custom_alignment_heads: Optional[str] = None,
adapters: Optional[Dict[str, str]] = None,
) -> tuple:
"""
Load a Whisper model with a LoRA adapter manager for dynamic adapter swapping.
This allows you to load multiple LoRA adapters and switch between them at runtime
without keeping multiple full models in memory.
Parameters
----------
name : str
Model name or path (same as load_model)
device : Union[str, torch.device]
Device to load model on
download_root : str
Download directory for model files
in_memory : bool
Whether to preload model weights into host memory
decoder_only : bool
If True, only load the decoder (no encoder)
custom_alignment_heads : str
Custom alignment heads configuration
adapters : Dict[str, str]
Optional dict mapping adapter names to paths/HuggingFace repo IDs.
Example: {"french": "path/to/french-lora", "spanish": "user/spanish-whisper-lora"}
Returns
-------
model : Whisper
The base Whisper model (without any LoRA baked in)
manager : LoRAAdapterManager
The adapter manager for loading/switching adapters
Example
-------
>>> model, manager = load_model_with_lora_manager(
... "large-v3",
... adapters={
... "french": "path/to/french-lora",
... "spanish": "path/to/spanish-lora"
... }
... )
>>>
>>> # Switch to French adapter
>>> manager.set_adapter("french")
>>> result_fr = model.transcribe(audio_fr)
>>>
>>> # Switch to Spanish adapter
>>> manager.set_adapter("spanish")
>>> result_es = model.transcribe(audio_es)
>>>
>>> # Use base model without LoRA
>>> manager.set_adapter(None)
>>> result_base = model.transcribe(audio)
>>>
>>> # Check memory usage
>>> print(manager.get_memory_usage())
{'french': 12.5, 'spanish': 12.5} # MB per adapter
"""
# Load the base model WITHOUT any LoRA baked in
model = load_model(
name=name,
device=device,
download_root=download_root,
in_memory=in_memory,
decoder_only=decoder_only,
custom_alignment_heads=custom_alignment_heads,
lora_path=None, # Important: no baked-in LoRA
)
# Create the adapter manager
manager = LoRAAdapterManager(model)
# Load any provided adapters
if adapters:
for adapter_name, adapter_path in adapters.items():
manager.load_adapter(adapter_name, adapter_path)
return model, manager
def convert_encoder_to_coreml(
model_name = "base",
output_path= "whisper_encoder.mlpackage",

View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
Tuple, Union)
import numpy as np
import torch
@@ -146,16 +147,13 @@ class PyTorchInference(Inference):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = {}
self.hooks = []
key_modules = [block.attn.key for block in self.model.decoder.blocks]
value_modules = [block.attn.value for block in self.model.decoder.blocks]
self.kv_modules = key_modules + value_modules
self.kv_cache_ids = []
for block in self.model.decoder.blocks:
self.kv_cache_ids.append(block.attn.key_cache_id)
self.kv_cache_ids.append(block.attn.value_cache_id)
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
@@ -163,17 +161,14 @@ class PyTorchInference(Inference):
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
def cleanup_caching(self):
for hook in self.hooks:
hook.remove()
self.kv_cache = {}
self.hooks = []
def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))):
for module in self.kv_modules:
# update the key/value cache to contain the selected sequences
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
for cache_id in self.kv_cache_ids:
if cache_id in self.kv_cache:
# update the key/value cache to contain the selected sequences
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
class SequenceRanker:

View File

@@ -0,0 +1,473 @@
"""
Dynamic LoRA adapter support for Whisper models.
This module enables loading a single base Whisper model and dynamically swapping
between multiple LoRA adapters at runtime, saving GPU memory when working with
multiple language-specific fine-tuned models.
Usage:
from whisperlivekit.whisper import load_model
from whisperlivekit.whisper.lora import LoRAAdapterManager
# Load base model without any LoRA baked in
model = load_model("large-v3", device="cuda")
# Create adapter manager
manager = LoRAAdapterManager(model)
# Load multiple adapters (small memory footprint each)
manager.load_adapter("french", "path/to/french-lora")
manager.load_adapter("spanish", "path/to/spanish-lora")
# Switch between adapters at runtime
manager.set_adapter("french")
result_fr = model.transcribe(audio_fr)
manager.set_adapter("spanish")
result_es = model.transcribe(audio_es)
# Disable LoRA (use base model only)
manager.set_adapter(None)
"""
import json
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor, nn
from .model import Linear
@dataclass
class LoRAConfig:
"""Configuration for a LoRA adapter."""
r: int # LoRA rank
alpha: float # LoRA alpha (scaling factor)
target_modules: List[str] = field(default_factory=list)
@property
def scaling(self) -> float:
return self.alpha / self.r
@dataclass
class LoRAAdapter:
"""Holds the LoRA A/B weight matrices for a single adapter."""
name: str
config: LoRAConfig
# Maps target module name -> (A matrix, B matrix)
weights: Dict[str, Tuple[Tensor, Tensor]] = field(default_factory=dict)
device: torch.device = field(default_factory=lambda: torch.device("cpu"))
dtype: torch.dtype = field(default=torch.float32)
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
"""Move adapter weights to specified device/dtype."""
self.device = device
if dtype is not None:
self.dtype = dtype
self.weights = {
name: (a.to(device=device, dtype=dtype or self.dtype),
b.to(device=device, dtype=dtype or self.dtype))
for name, (a, b) in self.weights.items()
}
return self
def memory_footprint_mb(self) -> float:
"""Return approximate memory usage in MB."""
total_bytes = 0
for a, b in self.weights.values():
total_bytes += a.numel() * a.element_size()
total_bytes += b.numel() * b.element_size()
return total_bytes / (1024 * 1024)
class LoRALinear(nn.Module):
"""
A Linear layer wrapper that supports dynamic LoRA injection.
The base weights remain unchanged. LoRA is applied additively during forward:
output = base_linear(x) + (x @ A @ B) * scaling
"""
def __init__(self, base_linear: Linear):
super().__init__()
self.base_linear = base_linear
self.lora_A: Optional[Tensor] = None
self.lora_B: Optional[Tensor] = None
self.scaling: float = 1.0
self._lora_enabled: bool = False
def set_lora(self, A: Optional[Tensor], B: Optional[Tensor], scaling: float = 1.0):
"""Set the LoRA matrices for this layer."""
self.lora_A = A
self.lora_B = B
self.scaling = scaling
self._lora_enabled = A is not None and B is not None
def clear_lora(self):
"""Remove LoRA from this layer."""
self.lora_A = None
self.lora_B = None
self._lora_enabled = False
def forward(self, x: Tensor) -> Tensor:
# Base linear output
out = self.base_linear(x)
# Add LoRA contribution if enabled
if self._lora_enabled and self.lora_A is not None and self.lora_B is not None:
# x: (..., in_features)
# A: (in_features, r)
# B: (r, out_features)
# lora_out: (..., out_features)
lora_out = (x @ self.lora_A.to(x.dtype)) @ self.lora_B.to(x.dtype)
out = out + lora_out * self.scaling
return out
# Delegate attribute access to base_linear for compatibility
@property
def weight(self):
return self.base_linear.weight
@property
def bias(self):
return self.base_linear.bias
@property
def in_features(self):
return self.base_linear.in_features
@property
def out_features(self):
return self.base_linear.out_features
# Mapping from HuggingFace LoRA module names to Whisper module paths
_HF_TO_WHISPER_MODULE_MAP = {
# Encoder attention
"model.encoder.layers.{}.self_attn.q_proj": "encoder.blocks.{}.attn.query",
"model.encoder.layers.{}.self_attn.k_proj": "encoder.blocks.{}.attn.key",
"model.encoder.layers.{}.self_attn.v_proj": "encoder.blocks.{}.attn.value",
"model.encoder.layers.{}.self_attn.out_proj": "encoder.blocks.{}.attn.out",
# Encoder MLP
"model.encoder.layers.{}.fc1": "encoder.blocks.{}.mlp.0",
"model.encoder.layers.{}.fc2": "encoder.blocks.{}.mlp.2",
# Decoder self-attention
"model.decoder.layers.{}.self_attn.q_proj": "decoder.blocks.{}.attn.query",
"model.decoder.layers.{}.self_attn.k_proj": "decoder.blocks.{}.attn.key",
"model.decoder.layers.{}.self_attn.v_proj": "decoder.blocks.{}.attn.value",
"model.decoder.layers.{}.self_attn.out_proj": "decoder.blocks.{}.attn.out",
# Decoder cross-attention
"model.decoder.layers.{}.encoder_attn.q_proj": "decoder.blocks.{}.cross_attn.query",
"model.decoder.layers.{}.encoder_attn.k_proj": "decoder.blocks.{}.cross_attn.key",
"model.decoder.layers.{}.encoder_attn.v_proj": "decoder.blocks.{}.cross_attn.value",
"model.decoder.layers.{}.encoder_attn.out_proj": "decoder.blocks.{}.cross_attn.out",
# Decoder MLP
"model.decoder.layers.{}.fc1": "decoder.blocks.{}.mlp.0",
"model.decoder.layers.{}.fc2": "decoder.blocks.{}.mlp.2",
}
def _normalize_hf_module_name(name: str) -> str:
"""Normalize HF-style LoRA module names."""
if name.startswith("base_model."):
name = name[len("base_model."):]
if name.startswith("model.model."):
name = name[len("model."):]
if not name.startswith("model."):
name = f"model.{name}"
return name
def _map_hf_to_whisper_module(hf_name: str) -> Optional[str]:
"""Map a HuggingFace LoRA module name to Whisper module path."""
hf_name = _normalize_hf_module_name(hf_name)
# Try to match with layer index patterns
import re
# Match patterns like model.encoder.layers.5.self_attn.q_proj
for pattern, target_pattern in _HF_TO_WHISPER_MODULE_MAP.items():
# Create regex from pattern (replace {} with capture group)
regex = pattern.replace(".", r"\.").replace("{}", r"(\d+)")
match = re.fullmatch(regex, hf_name)
if match:
layer_idx = match.group(1)
return target_pattern.format(layer_idx)
return None
def _get_module_by_path(model: nn.Module, path: str) -> Optional[nn.Module]:
"""Get a submodule by dot-separated path."""
parts = path.split(".")
current = model
for part in parts:
if hasattr(current, part):
current = getattr(current, part)
elif hasattr(current, "__getitem__"):
try:
current = current[int(part)]
except (ValueError, IndexError, KeyError):
return None
else:
return None
return current
def _set_module_by_path(model: nn.Module, path: str, module: nn.Module):
"""Set a submodule by dot-separated path."""
parts = path.split(".")
parent = model
for part in parts[:-1]:
if hasattr(parent, part):
parent = getattr(parent, part)
elif hasattr(parent, "__getitem__"):
parent = parent[int(part)]
setattr(parent, parts[-1], module)
class LoRAAdapterManager:
"""
Manages multiple LoRA adapters for a Whisper model.
Enables loading multiple adapters and switching between them at runtime
without reloading the full model.
"""
def __init__(self, model: nn.Module):
"""
Initialize the adapter manager.
Args:
model: A Whisper model instance
"""
self.model = model
self.adapters: Dict[str, LoRAAdapter] = {}
self.current_adapter: Optional[str] = None
self._lora_layers: Dict[str, LoRALinear] = {}
self._original_layers: Dict[str, Linear] = {}
self._initialized = False
def _initialize_lora_layers(self, target_modules: List[str]):
"""
Replace target Linear layers with LoRALinear wrappers.
This is done lazily on first adapter load.
"""
if self._initialized:
return
# Find and wrap all potential LoRA target modules
for whisper_path in target_modules:
module = _get_module_by_path(self.model, whisper_path)
if module is None:
continue
if isinstance(module, Linear) and not isinstance(module, LoRALinear):
# Wrap the Linear layer
lora_linear = LoRALinear(module)
_set_module_by_path(self.model, whisper_path, lora_linear)
self._lora_layers[whisper_path] = lora_linear
self._original_layers[whisper_path] = module
self._initialized = True
def _resolve_lora_path(self, lora_path: str) -> str:
"""Resolve LoRA path, downloading from HuggingFace Hub if needed."""
if os.path.isdir(lora_path):
return lora_path
# Try HuggingFace Hub
if "/" in lora_path:
try:
from huggingface_hub import snapshot_download
return snapshot_download(
repo_id=lora_path,
allow_patterns=["adapter_config.json", "adapter_model.*"],
)
except Exception as e:
raise FileNotFoundError(
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
)
raise FileNotFoundError(f"LoRA path '{lora_path}' not found.")
def _load_adapter_weights(self, lora_path: str) -> Dict[str, Tensor]:
"""Load adapter weights from safetensors or bin file."""
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin")
if os.path.isfile(safe_path):
from safetensors.torch import load_file
return load_file(safe_path)
elif os.path.isfile(bin_path):
return torch.load(bin_path, map_location="cpu")
else:
raise FileNotFoundError(
f"No adapter weights found in {lora_path}. "
"Expected adapter_model.safetensors or adapter_model.bin."
)
def load_adapter(
self,
name: str,
lora_path: str,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAAdapter:
"""
Load a LoRA adapter from disk or HuggingFace Hub.
Args:
name: Unique name for this adapter (e.g., "french", "spanish")
lora_path: Local path or HuggingFace repo ID
device: Device to load weights to (default: model's device)
dtype: Data type for weights (default: model's dtype)
Returns:
The loaded LoRAAdapter
"""
if device is None:
device = next(self.model.parameters()).device
if dtype is None:
dtype = next(self.model.parameters()).dtype
# Resolve path
lora_path = self._resolve_lora_path(lora_path)
# Load config
config_path = os.path.join(lora_path, "adapter_config.json")
if not os.path.isfile(config_path):
raise FileNotFoundError(f"Missing adapter_config.json in {lora_path}")
with open(config_path, "r", encoding="utf-8") as f:
config_dict = json.load(f)
if config_dict.get("peft_type") != "LORA":
raise ValueError("Only LoRA adapters are supported.")
config = LoRAConfig(
r=config_dict["r"],
alpha=config_dict.get("lora_alpha") or config_dict.get("alpha"),
target_modules=config_dict.get("target_modules", []),
)
# Load weights
adapter_state = self._load_adapter_weights(lora_path)
# Parse LoRA A/B matrices and map to Whisper module paths
lora_layers: Dict[str, Dict[str, Tensor]] = {}
for key, tensor in adapter_state.items():
if key.endswith("lora_A.weight"):
module = key[:-len(".lora_A.weight")]
lora_layers.setdefault(module, {})["A"] = tensor
elif key.endswith("lora_B.weight"):
module = key[:-len(".lora_B.weight")]
lora_layers.setdefault(module, {})["B"] = tensor
# Map to Whisper module paths and collect weights
weights: Dict[str, Tuple[Tensor, Tensor]] = {}
whisper_paths = set()
for hf_module, parts in lora_layers.items():
if "A" not in parts or "B" not in parts:
continue
whisper_path = _map_hf_to_whisper_module(hf_module)
if whisper_path is None:
# Try direct mapping (module might already be in Whisper format)
whisper_path = hf_module
# A: (r, in_features) -> transpose to (in_features, r)
# B: (out_features, r) -> transpose to (r, out_features)
A = parts["A"].T # (in_features, r)
B = parts["B"].T # (r, out_features)
weights[whisper_path] = (A, B)
whisper_paths.add(whisper_path)
# Create adapter
adapter = LoRAAdapter(
name=name,
config=config,
weights=weights,
device=device,
dtype=dtype,
)
adapter.to(device, dtype)
# Initialize LoRA layers if not done yet
self._initialize_lora_layers(list(whisper_paths))
# Store adapter
self.adapters[name] = adapter
return adapter
def set_adapter(self, name: Optional[str]):
"""
Switch to a different adapter or disable LoRA.
Args:
name: Adapter name to activate, or None to disable all LoRA
"""
if name is not None and name not in self.adapters:
raise KeyError(f"Adapter '{name}' not loaded. Available: {list(self.adapters.keys())}")
# Clear all LoRA from layers
for lora_linear in self._lora_layers.values():
lora_linear.clear_lora()
self.current_adapter = name
if name is None:
return
# Apply the selected adapter
adapter = self.adapters[name]
for module_path, (A, B) in adapter.weights.items():
if module_path in self._lora_layers:
self._lora_layers[module_path].set_lora(A, B, adapter.config.scaling)
def unload_adapter(self, name: str):
"""
Unload an adapter from memory.
Args:
name: Name of adapter to unload
"""
if name not in self.adapters:
return
if self.current_adapter == name:
self.set_adapter(None)
del self.adapters[name]
def list_adapters(self) -> List[str]:
"""Return list of loaded adapter names."""
return list(self.adapters.keys())
def get_memory_usage(self) -> Dict[str, float]:
"""Return memory usage in MB for each loaded adapter."""
return {name: adapter.memory_footprint_mb() for name, adapter in self.adapters.items()}
def restore_original_layers(self):
"""
Restore the original Linear layers, removing LoRA wrappers.
Call this if you want to go back to the original model structure.
"""
for path, original in self._original_layers.items():
_set_module_by_path(self.model, path, original)
self._lora_layers.clear()
self._original_layers.clear()
self._initialized = False
self.current_adapter = None

View File

@@ -79,18 +79,23 @@ def disable_sdpa():
class MultiHeadAttention(nn.Module):
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
use_sdpa = False # Disable SDPA to ensure qk is always computed when needed
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
def __init__(self, n_state: int, n_head: int, cache_id: str = "", n_text_ctx: int = 448):
super().__init__()
self.n_head = n_head
self.n_text_ctx = n_text_ctx
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
self.cache_id = cache_id
self.key.cache_id = f"{cache_id}_key"
self.value.cache_id = f"{cache_id}_value"
# Cache IDs for key and value (used with dict-based kv_cache)
self.key_cache_id = f"{cache_id}_key"
self.value_cache_id = f"{cache_id}_value"
# Keep these for backward compatibility with hook-based caching
self.key.cache_id = self.key_cache_id
self.value.cache_id = self.value_cache_id
def forward(
self,
@@ -101,19 +106,45 @@ class MultiHeadAttention(nn.Module):
):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
if xa is None:
# Self-attention
k = self.key(x)
v = self.value(x)
if kv_cache is not None:
k, v = self._update_self_attn_cache(k, v, kv_cache)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
# Cross-attention: compute once and cache, or reuse from cache
if kv_cache is not None and self.key_cache_id in kv_cache:
k = kv_cache[self.key_cache_id]
v = kv_cache[self.value_cache_id]
else:
k = self.key(xa)
v = self.value(xa)
if kv_cache is not None:
kv_cache[self.key_cache_id] = k
kv_cache[self.value_cache_id] = v
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def _update_self_attn_cache(
self, k: Tensor, v: Tensor, kv_cache: dict
) -> Tuple[Tensor, Tensor]:
"""Update self-attention kv cache by concatenating new k,v with cached values."""
if self.key_cache_id not in kv_cache or k.shape[1] > self.n_text_ctx:
# First token or context overflow: save as-is
kv_cache[self.key_cache_id] = k.detach()
kv_cache[self.value_cache_id] = v.detach()
else:
# Concatenate with existing cache
cached_k = kv_cache[self.key_cache_id]
cached_v = kv_cache[self.value_cache_id]
k = torch.cat([cached_k, k], dim=1).detach()
v = torch.cat([cached_v, v], dim=1).detach()
kv_cache[self.key_cache_id] = k
kv_cache[self.value_cache_id] = v
return k, v
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -143,14 +174,21 @@ class MultiHeadAttention(nn.Module):
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False,
cache_id: str = "", n_text_ctx: int = 448
):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
self.attn = MultiHeadAttention(
n_state, n_head, cache_id=f"{cache_id}_self_attn", n_text_ctx=n_text_ctx
)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
MultiHeadAttention(
n_state, n_head, cache_id=f"{cache_id}_cross_attn", n_text_ctx=n_text_ctx
) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
@@ -166,12 +204,21 @@ class ResidualAttentionBlock(nn.Module):
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Returns:
x: The output tensor
cross_attn_qk: Cross-attention weights (if cross_attn exists), else None
"""
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
cross_attn_qk = None
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
cross_out, cross_attn_qk = self.cross_attn(
self.cross_attn_ln(x), xa, kv_cache=kv_cache
)
x = x + cross_out
x = x + self.mlp(self.mlp_ln(x))
return x
return x, cross_attn_qk
class AudioEncoder(nn.Module):
@@ -201,7 +248,7 @@ class AudioEncoder(nn.Module):
x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x = block(x)
x, _ = block(x) # Encoder blocks don't have cross-attention
x = self.ln_post(x)
return x
@@ -212,13 +259,17 @@ class TextDecoder(nn.Module):
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.n_ctx = n_ctx
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
ResidualAttentionBlock(
n_state, n_head, cross_attention=True,
cache_id=f"dec_layer{i}", n_text_ctx=n_ctx
)
for i in range(n_layer)
]
)
@@ -227,28 +278,57 @@ class TextDecoder(nn.Module):
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
def forward(
self,
x: Tensor,
xa: Tensor,
kv_cache: Optional[dict] = None,
return_cross_attn: bool = False,
):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
kv_cache : Optional[dict]
Dictionary to store/retrieve key-value cache for efficient decoding
return_cross_attn : bool
If True, return cross-attention weights from all decoder layers
Returns
-------
logits : Tensor
The output logits
cross_attns : Optional[List[Tensor]]
List of cross-attention weights per layer (only if return_cross_attn=True)
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
# Calculate offset from self-attention cache (not cross-attention which has audio length)
offset = 0
if kv_cache:
# Use the first decoder block's self-attention key cache to get token position
first_self_attn_key = self.blocks[0].attn.key_cache_id
if first_self_attn_key in kv_cache:
offset = kv_cache[first_self_attn_key].shape[1]
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
cross_attns = [] if return_cross_attn else None
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x, cross_attn_qk = block(x, xa, mask=self.mask, kv_cache=kv_cache)
if return_cross_attn and cross_attn_qk is not None:
cross_attns.append(cross_attn_qk)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
if return_cross_attn:
return logits, cross_attns
return logits
@@ -292,8 +372,18 @@ class Whisper(nn.Module):
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def logits(
self,
tokens: torch.Tensor,
audio_features: torch.Tensor,
kv_cache: Optional[dict] = None,
return_cross_attn: bool = False,
):
return self.decoder(
tokens, audio_features,
kv_cache=kv_cache,
return_cross_attn=return_cross_attn
)
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
@@ -312,39 +402,6 @@ class Whisper(nn.Module):
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function

View File

@@ -8,28 +8,13 @@ import numpy as np
import torch
import tqdm
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES,
SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
exact_div,
format_timestamp,
get_end,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
from .utils import (exact_div, format_timestamp, get_end, get_writer,
make_safe, optional_float, optional_int, str2bool)
if TYPE_CHECKING:
from .model import Whisper