7 Commits

Author SHA1 Message Date
Quentin Fuxa
60c62f8f84 troubleshooting #271 #276 #284 #286 2025-11-25 23:31:46 +01:00
Quentin Fuxa
7faa21f95f alignatt: enable model sharing by removing hooks and centralizing session state. Solves #282
Co-authored-by: Emmanuel Schmidbauer <eschmidbauer@gmail.com>
2025-11-25 23:07:42 +01:00
Quentin Fuxa
4e9f951551 correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
870141298c isort 2025-11-23 11:20:00 +01:00
Quentin Fuxa
872faa422a correct silences handling when language not auto 2025-11-20 11:20:00 +01:00
Quentin Fuxa
fc9cb66813 disabling vac is not advised 2025-11-23 11:20:00 +01:00
Quentin Fuxa
a175d1a327 fixes silence detected but never reported by silero 2025-11-23 11:20:00 +01:00
39 changed files with 823 additions and 781 deletions

View File

@@ -51,9 +51,11 @@ pip install whisperlivekit
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time! 2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages. > - See [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options. > - Check the [troubleshooting guide](docs/troubleshooting.md) for step-by-step fixes collected from recent GPU setup/env issues.
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent. > - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
#### Use it to capture audio from web pages. #### Use it to capture audio from web pages.
@@ -96,11 +98,13 @@ wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes. **Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
```python ```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio from whisperlivekit import AudioProcessor, TranscriptionEngine, parse_args
transcription_engine = None transcription_engine = None
@@ -146,8 +150,8 @@ async def websocket_endpoint(websocket: WebSocket):
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` | | `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` | | `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
| `--no-vac` | Disable Voice Activity Controller | `False` | | `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` | | `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` | | `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` | | `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` | | `--port` | Server port | `8000` |
@@ -183,7 +187,6 @@ async def websocket_endpoint(websocket: WebSocket):
| `--init-prompt` | Initial prompt for the model | `None` | | `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` | | `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` | | `--max-context-tokens` | Maximum context tokens | `None` |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |

View File

@@ -1,258 +0,0 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
#### 主要な研究による技術:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
### アーキテクチャ
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
### インストールとクイックスタート
```bash
pip install whisperlivekit
```
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
>
> | OS | インストール方法 |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
#### クイックスタート
1. **文字起こしサーバーを起動します:**
```bash
whisperlivekit-server --model base --language en
```
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
#### オプションの依存関係
| オプション | `pip install` |
|-----------|-------------|
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Diartによる話者ダイアライゼーション | `diart` |
| オリジナルのWhisperバックエンド | `whisper` |
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
| Apple Silicon最適化バックエンド | `mlx-whisper` |
| OpenAI APIバックエンド | `openai` |
それらの使用方法については、以下の**パラメータと設定**を参照してください。
### 使用例
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
```bash
# デフォルト(small)より良いモデルを使用
whisperlivekit-server --model large-v3
# ダイアライゼーションと言語を指定した高度な設定
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket.accept()
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
```
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
## パラメータと設定
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
- `--backend` `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
- `--warmup-file`、もしあれば
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
- `--diarization`、使用したい場合。
残りは推奨しません。しかし、以下があなたのオプションです。
| パラメータ | 説明 | デフォルト |
|-----------|-------------|---------|
| `--model` | Whisperモデルのサイズ。 | `small` |
| `--language` | ソース言語コードまたは`auto` | `auto` |
| `--task` | `transcribe`または`translate` | `transcribe` |
| `--backend` | 処理バックエンド | `simulstreaming` |
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
| `--no-vad` | 音声区間検出を無効化 | `False` |
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
| `--host` | サーバーホストアドレス | `localhost` |
| `--port` | サーバーポート | `8000` |
| `--ssl-certfile` | SSL証明書ファイルへのパス(HTTPSサポート用) | `None` |
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパス(HTTPSサポート用) | `None` |
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment`) | `segment` |
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAttフレームしきい値(低いほど速く、高いほど正確) | `25` |
| `--beams` | ビームサーチのビーム数(1 = 貪欲デコーディング) | `1` |
| `--decoder` | デコーダタイプを強制(`beam`または`greedy`) | `auto` |
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
| `--init-prompt` | モデルの初期プロンプト | `None` |
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
| ダイアライゼーションオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--diarization` | 話者識別を有効化 | `False` |
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
> 4. HuggingFaceでログイン: `huggingface-cli login`
### 🚀 デプロイガイド
WhisperLiveKitを本番環境にデプロイするには
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
3. **Nginx設定** (本番環境で推奨):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
## 🐋 Docker
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
### 前提条件
- Dockerがシステムにインストールされていること
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
### クイックスタート
**GPUアクセラレーション付き (推奨):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
**CPUのみ:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### 高度な使用法
**カスタム設定:**
```bash
# カスタムモデルと言語の例
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### メモリ要件
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
#### カスタマイズ
- `--build-arg` オプション:
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
## 🔮 ユースケース
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...

113
docs/troubleshooting.md Normal file
View File

@@ -0,0 +1,113 @@
# Troubleshooting
## GPU drivers & cuDNN visibility
### Linux error: `Unable to load libcudnn_ops.so* / cudnnCreateTensorDescriptor`
> Reported in issue #271 (Arch/CachyOS)
`faster-whisper` (used for the SimulStreaming encoder) dynamically loads cuDNN.
If the runtime cannot find `libcudnn_*`, verify that CUDA and cuDNN match the PyTorch build you installed:
1. **Install CUDA + cuDNN** (Arch/CachyOS example):
```bash
sudo pacman -S cuda cudnn
sudo ldconfig
```
2. **Make sure the shared objects are visible**:
```bash
ls /usr/lib/libcudnn*
```
3. **Check what CUDA version PyTorch expects** and match that with the driver you installed:
```bash
python - <<'EOF'
import torch
print(torch.version.cuda)
EOF
nvcc --version
```
4. If you installed CUDA in a non-default location, export `CUDA_HOME` and add `$CUDA_HOME/lib64` to `LD_LIBRARY_PATH`.
Once the CUDA/cuDNN versions match, `whisperlivekit-server` starts normally.
### Windows error: `Could not locate cudnn_ops64_9.dll`
> Reported in issue #286 (Conda on Windows)
PyTorch bundles cuDNN DLLs inside your environment (`<env>\Lib\site-packages\torch\lib`).
When `ctranslate2` or `faster-whisper` cannot find `cudnn_ops64_9.dll`:
1. Locate the DLL shipped with PyTorch, e.g.
```
E:\conda\envs\WhisperLiveKit\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
```
2. Add that directory to your `PATH` **or** copy the `cudnn_*64_9.dll` files into a directory that is already on `PATH` (such as the environment's `Scripts/` folder).
3. Restart the shell before launching `wlk`.
Installing NVIDIA's standalone cuDNN 9.x and pointing `PATH`/`CUDNN_PATH` to it works as well, but is usually not required.
---
## PyTorch / CTranslate2 GPU builds
### `Torch not compiled with CUDA enabled`
> Reported in issue #284
If `torch.zeros(1).cuda()` raises that assertion it means you installed a CPU-only wheel.
Install the GPU-enabled wheels that match your CUDA toolkit:
```bash
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
```
Replace `cu130` with the CUDA version supported by your driver (see [PyTorch install selector](https://pytorch.org/get-started/locally/)).
Validate with:
```python
import torch
print(torch.cuda.is_available(), torch.cuda.get_device_name())
```
### `CTranslate2 device count: 0` or `Could not infer dtype of ctranslate2._ext.StorageView`
> Follow-up in issue #284
`ctranslate2` publishes separate CPU and CUDA wheels. The default `pip install ctranslate2` brings the CPU build, which makes WhisperLiveKit fall back to CPU tensors and leads to the dtype error above.
1. Uninstall the CPU build: `pip uninstall -y ctranslate2`.
2. Install the CUDA wheel that matches your toolkit (example for CUDA 13.0):
```bash
pip install ctranslate2==4.5.0 -f https://opennmt.net/ctranslate2/whl/cu130
```
(See the [CTranslate2 installation table](https://opennmt.net/CTranslate2/installation.html) for other CUDA versions.)
3. Verify:
```python
import ctranslate2
print("CUDA devices:", ctranslate2.get_cuda_device_count())
```
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
---
## Hopper / Blackwell (`sm_121a`) systems
> Reported in issue #276 (NVIDIA DGX Spark)
GPUs that report compute capability `12.1a` (i.e. `sm_121a`) can ship before some toolchains recognize the architecture ID, so Triton/PTXAS need manual hints:
```bash
export CUDA_HOME="/usr/local/cuda-13.0"
export PATH="$CUDA_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
# Tell Triton where the new ptxas lives
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
# Force PyTorch to JIT kernels for all needed architectures
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
```
After exporting those variables (or adding them to your systemd service / shell profile), restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
---
Need help with another recurring issue? Open a GitHub discussion or PR and reference this document so we can keep it current.

View File

@@ -61,10 +61,10 @@ packages = [
"whisperlivekit.whisper.normalizers", "whisperlivekit.whisper.normalizers",
"whisperlivekit.web", "whisperlivekit.web",
"whisperlivekit.local_agreement", "whisperlivekit.local_agreement",
"whisperlivekit.vad_models" "whisperlivekit.silero_vad_models"
] ]
[tool.setuptools.package-data] [tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"] whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"] "whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"] "whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -14,10 +14,10 @@ from typing import Dict, Tuple
import torch import torch
from whisperlivekit.whisper import _convert_hf_state_dict
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
from whisperlivekit.whisper.model import ModelDimensions from whisperlivekit.whisper.model import ModelDimensions
from whisperlivekit.whisper.utils import exact_div from whisperlivekit.whisper.utils import exact_div
from whisperlivekit.whisper import _convert_hf_state_dict
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]: def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:

View File

@@ -5,16 +5,18 @@ import argparse
import base64 import base64
import gzip import gzip
import io import io
import math
import pathlib import pathlib
import sys import sys
import math
from typing import List, Optional, Sequence, Tuple, Union from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from datasets import Audio as DatasetAudio, load_dataset
import soundfile as sf
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from datasets import Audio as DatasetAudio
from datasets import load_dataset
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
WHISPER_ROOT = REPO_ROOT / "whisper" WHISPER_ROOT = REPO_ROOT / "whisper"

View File

@@ -1,9 +1,10 @@
"""Copy core files from web directory to Chrome extension directory.""" """Copy core files from web directory to Chrome extension directory."""
import shutil
import os import os
import shutil
from pathlib import Path from pathlib import Path
def sync_extension_files(): def sync_extension_files():
web_dir = Path("whisperlivekit/web") web_dir = Path("whisperlivekit/web")

View File

@@ -1,7 +1,7 @@
from .audio_processor import AudioProcessor from .audio_processor import AudioProcessor
from .core import TranscriptionEngine from .core import TranscriptionEngine
from .parse_args import parse_args from .parse_args import parse_args
from .web.web_interface import get_web_interface_html, get_inline_ui_html from .web.web_interface import get_inline_ui_html, get_web_interface_html
__all__ = [ __all__ = [
"TranscriptionEngine", "TranscriptionEngine",

View File

@@ -1,14 +1,20 @@
import asyncio import asyncio
import numpy as np
from time import time
import logging import logging
import traceback import traceback
from typing import Optional, Union, List, Any, AsyncGenerator from time import time
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker from typing import Any, AsyncGenerator, List, Optional, Union
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator import numpy as np
from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory,
online_translation_factory)
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
Line, Silence, State, Transcript)
from whisperlivekit.tokens_alignment import TokensAlignment from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@@ -603,16 +609,16 @@ class AudioProcessor:
res = self.vac(pcm_array) res = self.vac(pcm_array)
if res is not None: if res is not None:
silence_detected = res.get("end", 0) > res.get("start", 0) if "start" in res and self.current_silence:
if silence_detected and not self.current_silence: await self._end_silence()
if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence( pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end") pcm_array, chunk_sample_start, res.get("end")
) )
if pre_silence_chunk is not None and pre_silence_chunk.size > 0: if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
await self._enqueue_active_audio(pre_silence_chunk) await self._enqueue_active_audio(pre_silence_chunk)
await self._begin_silence() await self._begin_silence()
elif self.current_silence:
await self._end_silence()
if not self.current_silence: if not self.current_silence:
await self._enqueue_active_audio(pcm_array) await self._enqueue_active_audio(pcm_array)

View File

@@ -1,10 +1,13 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio import asyncio
import logging import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
get_inline_ui_html, parse_args)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)

View File

@@ -1,9 +1,11 @@
import logging
import sys
from argparse import Namespace
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
from whisperlivekit.local_agreement.whisper_online import backend_factory from whisperlivekit.local_agreement.whisper_online import backend_factory
from whisperlivekit.simul_whisper import SimulStreamingASR from whisperlivekit.simul_whisper import SimulStreamingASR
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
from argparse import Namespace
import sys
import logging
def update_with_kwargs(_dict, kwargs): def update_with_kwargs(_dict, kwargs):
_dict.update({ _dict.update({
@@ -80,6 +82,7 @@ class TranscriptionEngine:
if self.args.vac: if self.args.vac:
from whisperlivekit.silero_vad_iterator import load_silero_vad from whisperlivekit.silero_vad_iterator import load_silero_vad
# Use ONNX if specified, otherwise use JIT (default) # Use ONNX if specified, otherwise use JIT (default)
use_onnx = kwargs.get('vac_onnx', False) use_onnx = kwargs.get('vac_onnx', False)
self.vac_model = load_silero_vad(onnx=use_onnx) self.vac_model = load_silero_vad(onnx=use_onnx)
@@ -100,7 +103,6 @@ class TranscriptionEngine:
"init_prompt": None, "init_prompt": None,
"static_init_prompt": None, "static_init_prompt": None,
"max_context_tokens": None, "max_context_tokens": None,
"preload_model_count": 1,
} }
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs) simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
@@ -135,7 +137,8 @@ class TranscriptionEngine:
if self.args.diarization: if self.args.diarization:
if self.args.diarization_backend == "diart": if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization from whisperlivekit.diarization.diart_backend import \
DiartDiarization
diart_params = { diart_params = {
"segmentation_model": "pyannote/segmentation-3.0", "segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding", "embedding_model": "pyannote/embedding",
@@ -146,7 +149,8 @@ class TranscriptionEngine:
**diart_params **diart_params
) )
elif self.args.diarization_backend == "sortformer": elif self.args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization from whisperlivekit.diarization.sortformer_backend import \
SortformerDiarization
self.diarization_model = SortformerDiarization() self.diarization_model = SortformerDiarization()
self.translation_model = None self.translation_model = None
@@ -182,7 +186,8 @@ def online_diarization_factory(args, diarization_backend):
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
if args.diarization_backend == "sortformer": if args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline from whisperlivekit.diarization.sortformer_backend import \
SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend) online = SortformerDiarizationOnline(shared_model=diarization_backend)
return online return online

View File

@@ -1,20 +1,20 @@
import asyncio import asyncio
import logging
import re import re
import threading import threading
import numpy as np
import logging
import time import time
from queue import SimpleQueue, Empty from queue import Empty, SimpleQueue
from typing import Any, List, Tuple
import diart.models as m
import numpy as np
from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference from diart.inference import StreamingInference
from diart.sources import AudioSource from diart.sources import AudioSource, MicrophoneAudioSource
from whisperlivekit.timed_objects import SpeakerSegment
from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List
from pyannote.core import Annotation from pyannote.core import Annotation
import diart.models as m from rx.core import Observer
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,11 +1,12 @@
import numpy as np
import torch
import logging import logging
import threading import threading
import time import time
import wave import wave
from queue import Empty, SimpleQueue
from typing import List, Optional from typing import List, Optional
from queue import SimpleQueue, Empty
import numpy as np
import torch
from whisperlivekit.timed_objects import SpeakerSegment from whisperlivekit.timed_objects import SpeakerSegment
@@ -295,6 +296,7 @@ def extract_number(s: str) -> int:
if __name__ == '__main__': if __name__ == '__main__':
import asyncio import asyncio
import librosa import librosa
async def main(): async def main():

View File

@@ -1,8 +1,8 @@
import asyncio import asyncio
import contextlib
import logging import logging
from enum import Enum from enum import Enum
from typing import Optional, Callable from typing import Callable, Optional
import contextlib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)

View File

@@ -1,13 +1,16 @@
import sys
import logging
import io import io
import soundfile as sf import logging
import math import math
import sys
from typing import List from typing import List
import numpy as np import numpy as np
import soundfile as sf
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.timed_objects import ASRToken from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ASRBase: class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped, sep = " " # join transcribe words with this character (" " for whisper_timestamped,
@@ -165,8 +168,8 @@ class MLXWhisper(ASRBase):
sep = "" sep = ""
def load_model(self, model_size=None, cache_dir=None, model_dir=None): def load_model(self, model_size=None, cache_dir=None, model_dir=None):
from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx import mlx.core as mx
from mlx_whisper.transcribe import ModelHolder, transcribe
if model_dir is not None: if model_dir is not None:
resolved_path = resolve_model_path(model_dir) resolved_path = resolve_model_path(model_dir)

View File

@@ -1,7 +1,9 @@
import sys
import numpy as np
import logging import logging
from typing import List, Tuple, Optional import sys
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,18 +1,19 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging import logging
import platform import platform
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR import sys
import time
from functools import lru_cache
import librosa
import numpy as np
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.warmup import warmup_asr from whisperlivekit.warmup import warmup_asr
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
from whisperlivekit.backend_support import ( from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
mlx_backend_available,
faster_backend_available,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,6 +1,7 @@
from argparse import ArgumentParser from argparse import ArgumentParser
def parse_args(): def parse_args():
parser = ArgumentParser(description="Whisper FastAPI Online Server") parser = ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument( parser.add_argument(
@@ -295,14 +296,6 @@ def parse_args():
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.", help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
) )
simulstreaming_group.add_argument(
"--preload-model-count",
type=int,
default=1,
dest="preload_model_count",
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
)
simulstreaming_group.add_argument( simulstreaming_group.add_argument(
"--nllb-backend", "--nllb-backend",
type=str, type=str,

View File

@@ -1,8 +1,9 @@
import torch
import numpy as np
import warnings import warnings
from pathlib import Path from pathlib import Path
import numpy as np
import torch
""" """
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
""" """
@@ -123,7 +124,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
raise Exception(f'Available ONNX opset_version: {available_ops}') raise Exception(f'Available ONNX opset_version: {available_ops}')
if model_path is None: if model_path is None:
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
data_dir = current_dir / 'vad_models' data_dir = current_dir / 'silero_vad_models'
if onnx: if onnx:
if opset_version == 16: if opset_version == 16:
@@ -138,7 +139,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
if not model_path.exists(): if not model_path.exists():
raise FileNotFoundError( raise FileNotFoundError(
f"Model file not found: {model_path}\n" f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files." f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
) )
else: else:
model_path = Path(model_path) model_path = Path(model_path)
@@ -276,8 +277,10 @@ class FixedVADIterator(VADIterator):
elif r is not None: elif r is not None:
if "end" in r: if "end" in r:
ret["end"] = r["end"] ret["end"] = r["end"]
if "start" in r and "end" in ret: if "start" in r:
del ret["end"] ret["start"] = r["start"]
if "end" in ret:
del ret["end"]
return ret if ret != {} else None return ret if ret != {} else None

View File

@@ -1,31 +1,30 @@
import sys import gc
import numpy as np
import logging import logging
from typing import List, Tuple, Optional import os
import platform import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker import sys
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
from whisperlivekit.warmup import load_file from whisperlivekit.warmup import load_file
from whisperlivekit.whisper import load_model, tokenizer from whisperlivekit.whisper import load_model, tokenizer
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
import os
import gc
from pathlib import Path
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.backend_support import (
mlx_backend_available,
faster_backend_available,
)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
else: else:
mlx_model_mapping = {} mlx_model_mapping = {}
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER) HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
@@ -50,20 +49,19 @@ class SimulStreamingOnlineProcessor:
self.buffer = [] self.buffer = []
self.committed: List[ASRToken] = [] self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = []
self.load_new_backend() self.load_new_alignatt_instance()
#can be moved
if asr.tokenizer: if asr.tokenizer:
self.model.tokenizer = asr.tokenizer self.model.tokenizer = asr.tokenizer
def load_new_backend(self): def load_new_alignatt_instance(self):
model = self.asr.get_new_model_instance() """Initialize AlignAtt decoder using the shared model."""
self.model = AlignAtt( self.model = AlignAtt(
cfg=self.asr.cfg, cfg=self.asr.cfg,
loaded_model=model, loaded_model=self.asr.shared_model,
mlx_encoder=self.asr.mlx_encoder, mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder, fw_encoder=self.asr.fw_encoder,
) )
def start_silence(self): def start_silence(self):
tokens, processed_upto = self.process_iter(is_last=True) tokens, processed_upto = self.process_iter(is_last=True)
@@ -71,7 +69,10 @@ class SimulStreamingOnlineProcessor:
def end_silence(self, silence_duration, offset): def end_silence(self, silence_duration, offset):
""" """
If silences are > MIN_DURATION_REAL_SILENCE, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame Handle silence period.
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
Otherwise, insert a small silence and shift the last_attend_frame.
""" """
self.end += silence_duration self.end += silence_duration
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
@@ -84,21 +85,20 @@ class SimulStreamingOnlineProcessor:
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
self.model.global_time_offset = silence_duration + offset self.model.global_time_offset = silence_duration + offset
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time): def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming.""" """Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor # Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float() audio_tensor = torch.from_numpy(audio).float()
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
self.model.insert_audio(audio_tensor) self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker): def new_speaker(self, change_speaker: ChangeSpeaker):
self.process_iter(is_last=True) """Handle speaker change event."""
self.model.refresh_segment(complete=True) self.process_iter(is_last=True)
self.model.speaker = change_speaker.speaker self.model.refresh_segment(complete=True)
self.global_time_offset = change_speaker.start self.model.speaker = change_speaker.speaker
self.model.global_time_offset = change_speaker.start
def get_buffer(self): def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
@@ -112,15 +112,17 @@ class SimulStreamingOnlineProcessor:
""" """
try: try:
timestamped_words = self.model.infer(is_last=is_last) timestamped_words = self.model.infer(is_last=is_last)
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
if not timestamped_words:
return [], self.end
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
self.buffer.extend(timestamped_words) self.buffer.extend(timestamped_words)
return [], self.end return [], self.end
self.committed.extend(timestamped_words) self.committed.extend(timestamped_words)
self.buffer = [] self.buffer = []
return timestamped_words, self.end return timestamped_words, self.end
except Exception as e: except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}") logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end return [], self.end
@@ -136,12 +138,8 @@ class SimulStreamingOnlineProcessor:
logger.exception(f"SimulStreaming warmup failed: {e}") logger.exception(f"SimulStreaming warmup failed: {e}")
def __del__(self): def __del__(self):
# free the model and add a new model to stack.
# del self.model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
# self.asr.new_model_to_stack()
self.model.remove_hooks()
class SimulStreamingASR(): class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy.""" """SimulStreaming backend with AlignAtt policy."""
@@ -226,10 +224,7 @@ class SimulStreamingASR():
self.tokenizer = self.set_translate_task() self.tokenizer = self.set_translate_task()
else: else:
self.tokenizer = None self.tokenizer = None
self.mlx_encoder, self.fw_encoder = None, None self.mlx_encoder, self.fw_encoder = None, None
if self.encoder_backend == "mlx-whisper": if self.encoder_backend == "mlx-whisper":
print('Simulstreaming will use MLX whisper to increase encoding speed.') print('Simulstreaming will use MLX whisper to increase encoding speed.')
@@ -253,8 +248,7 @@ class SimulStreamingASR():
device='auto', device='auto',
compute_type='auto', compute_type='auto',
) )
self.shared_model = self.load_model()
self.models = [self.load_model() for i in range(self.preload_model_count)]
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper): def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
@@ -303,11 +297,11 @@ class SimulStreamingASR():
download_root=self.model_path, download_root=self.model_path,
decoder_only=self.fast_encoder, decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads custom_alignment_heads=self.custom_alignment_heads
) )
warmup_audio = load_file(self.warmup_file) warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None: if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float() warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder: if self.fast_encoder:
temp_model = AlignAtt( temp_model = AlignAtt(
cfg=self.cfg, cfg=self.cfg,
loaded_model=whisper_model, loaded_model=whisper_model,
@@ -315,27 +309,9 @@ class SimulStreamingASR():
fw_encoder=self.fw_encoder, fw_encoder=self.fw_encoder,
) )
temp_model.warmup(warmup_audio) temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else: else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None) whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
return whisper_model return whisper_model
def get_new_model_instance(self):
"""
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
"""
if len(self.models) == 0:
self.models.append(self.load_model())
new_model = self.models.pop()
return new_model
# self.models[0]
def new_model_to_stack(self):
self.models.append(self.load_model())
def set_translate_task(self): def set_translate_task(self):
"""Set up translation task.""" """Set up translation task."""

View File

@@ -1,17 +1,32 @@
from torch import Tensor
from whisperlivekit.whisper.decoding import PyTorchInference from whisperlivekit.whisper.decoding import PyTorchInference
# extention of PyTorchInference for beam search
class BeamPyTorchInference(PyTorchInference):
def _kv_modules(self): class BeamPyTorchInference(PyTorchInference):
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks] """Extension of PyTorchInference for beam search with cross-attention support."""
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
return key_modules + value_modules def _kv_cache_ids(self):
"""Get cache_id strings for self-attention key/value modules."""
key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks]
value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks]
return key_ids + value_ids
def rearrange_kv_cache(self, source_indices): def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))): if source_indices != list(range(len(source_indices))):
for module_cache_id in self._kv_modules(): for cache_id in self._kv_cache_ids():
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach() if cache_id in self.kv_cache:
from torch import Tensor self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) def logits(
self,
tokens: Tensor,
audio_features: Tensor,
return_cross_attn: bool = False,
):
"""Get logits, optionally returning cross-attention weights."""
return self.model.decoder(
tokens, audio_features,
kv_cache=self.kv_cache,
return_cross_attn=return_cross_attn,
)

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal from typing import Literal
@dataclass @dataclass
class AlignAttConfig(): class AlignAttConfig():
eval_data_path: str = "tmp" eval_data_path: str = "tmp"

View File

@@ -0,0 +1,80 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import torch
@dataclass
class DecoderState:
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[torch.Tensor] = field(default_factory=list)
initial_tokens: Optional[torch.Tensor] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[torch.Tensor] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
CIFLinear: Optional[torch.nn.Module] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens_fn: Any = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.kv_cache = {}
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = self.kv_cache
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
"""
Reset transient state for a new segment.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = {}
self.first_timestamp = None

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten
from mlx_whisper import whisper from mlx_whisper import whisper
mlx_model_mapping = { mlx_model_mapping = {

View File

@@ -1,33 +1,36 @@
import os
import logging import logging
import torch
import torch.nn.functional as F
import numpy as np
from whisperlivekit.whisper import DecodingOptions, tokenizer
from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from whisperlivekit.whisper.timing import median_filter
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os import os
from time import time from time import time
from .token_buffer import TokenBuffer from typing import List, Optional, Tuple
from whisperlivekit.backend_support import (
mlx_backend_available, import numpy as np
faster_backend_available, import torch
) import torch.nn.functional as F
from whisperlivekit.backend_support import (faster_backend_available,
mlx_backend_available)
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
TOKENS_PER_SECOND,
log_mel_spectrogram, pad_or_trim)
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
SuppressTokens)
from whisperlivekit.whisper.timing import median_filter
from ..timed_objects import PUNCTUATION_MARKS from ..timed_objects import PUNCTUATION_MARKS
from .beam import BeamPyTorchInference
from .config import AlignAttConfig
from .decoder_state import DecoderState
from .eow_detection import fire_at_boundary, load_cif
from .token_buffer import TokenBuffer
DEC_PAD = 50257 DEC_PAD = 50257
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if mlx_backend_available(): if mlx_backend_available():
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram from mlx_whisper.audio import \
log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
if faster_backend_available(): if faster_backend_available():
@@ -52,6 +55,30 @@ def load_coreml_encoder():
class AlignAtt: class AlignAtt:
"""
Alignment-based Attention decoder for SimulStreaming.
This class is now hookless - the model can be shared across multiple
sessions, with each session maintaining its own DecoderState.
"""
# Property accessors for backward compatibility
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__( def __init__(
self, self,
cfg: AlignAttConfig, cfg: AlignAttConfig,
@@ -59,8 +86,7 @@ class AlignAtt:
mlx_encoder=None, mlx_encoder=None,
fw_encoder=None, fw_encoder=None,
) -> None: ) -> None:
self.log_segments = 0 # Shared model reference (can be shared across sessions)
self.model = loaded_model self.model = loaded_model
self.mlx_encoder = mlx_encoder self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder self.fw_encoder = fw_encoder
@@ -74,119 +100,89 @@ class AlignAtt:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
self.speaker = -1
self.decode_options = DecodingOptions( self.decode_options = DecodingOptions(
language = cfg.language, language=cfg.language,
without_timestamps = True, without_timestamps=True,
task=cfg.task task=cfg.task
) )
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
# self.create_tokenizer('en')
self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False
self.max_text_len = self.model.dims.n_text_ctx self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks) self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg self.cfg = cfg
self.l_hooks = []
# model to detect end-of-word boundary at the end of the segment
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
n_audio_state=self.model.dims.n_audio_state,
device=self.model.device)
# install hooks to access encoder-decoder attention
self.dec_attns = []
def layer_hook(module, net_input, net_output):
# net_output[1]: B*num_head*token_len*audio_len
t = F.softmax(net_output[1], dim=-1)
self.dec_attns.append(t.squeeze(0))
for b in self.model.decoder.blocks:
hook = b.cross_attn.register_forward_hook(layer_hook)
self.l_hooks.append(hook)
self.kv_cache = {}
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
# save as-is, for the first token or cross attention
self.kv_cache[module.cache_id] = net_output
else:
x = self.kv_cache[module.cache_id]
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
return self.kv_cache[module.cache_id]
for i,b in enumerate(self.model.decoder.blocks):
hooks = [
b.attn.key.register_forward_hook(kv_hook),
b.attn.value.register_forward_hook(kv_hook),
b.cross_attn.key.register_forward_hook(kv_hook),
b.cross_attn.value.register_forward_hook(kv_hook),
]
self.l_hooks.extend(hooks)
self.align_source = {}
self.num_align_heads = 0
for layer_rank, head_id in self.model.alignment_heads.indices().T:
layer_rank = layer_rank.item()
heads = self.align_source.get(layer_rank, [])
heads.append((self.num_align_heads, head_id.item()))
self.align_source[layer_rank] = heads
self.num_align_heads += 1
# tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
# self.tokenizer.eot
self.tokenizer.no_timestamps, # added by DM
] + list(self.tokenizer.all_language_tokens) # added by DM
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {suppress_tokens}")
sup_tokens = SuppressTokens(suppress_tokens)
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334
# it's going to be regenerated after lang id
self.segments = []
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.first_timestamp = None
if self.cfg.max_context_tokens is None: if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len self.max_context_tokens = self.max_text_len
else: else:
self.max_context_tokens = self.cfg.max_context_tokens self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = DecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
# Create tokenizer
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
# Timing state
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
# CIF helpers for end-of-word boundary detection
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
cfg,
n_audio_state=self.model.dims.n_audio_state,
device=self.model.device
)
# Build alignment source mapping from model's alignment_heads
self.state.align_source = {}
self.state.num_align_heads = 0
for layer_rank, head_id in self.model.alignment_heads.indices().T:
layer_rank = layer_rank.item()
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id.item()))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
# Build suppress tokens function
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {suppress_tokens}")
sup_tokens = SuppressTokens(suppress_tokens)
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
# Initialize tokens
self.init_tokens()
self.init_context() self.init_context()
# decoder type: greedy or beam # Set up decoder type
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy": if cfg.decoder_type == "greedy":
logger.info("Using greedy decoder") logger.info("Using greedy decoder")
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot) self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
self.decoder_type = "greedy"
elif cfg.decoder_type == "beam": elif cfg.decoder_type == "beam":
self.decoder_type = "beam" logger.info("Using beam decoder")
self.inference = BeamPyTorchInference(self.model, self.initial_token_length) self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length)
self.inference.kv_cache = self.kv_cache self.state.inference.kv_cache = self.state.kv_cache
self.state.token_decoder = BeamSearchDecoder(
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) inference=self.state.inference,
eot=self.tokenizer.eot,
# Tokens to carry over to next chunk for incomplete UTF-8 characters beam_size=cfg.beam_size
self.pending_incomplete_tokens = [] )
def remove_hooks(self):
for hook in self.l_hooks:
hook.remove()
def warmup(self, audio): def warmup(self, audio):
try: try:
@@ -204,96 +200,100 @@ class AlignAtt:
num_languages=self.model.num_languages, num_languages=self.model.num_languages,
task=self.decode_options.task task=self.decode_options.task
) )
self.state.tokenizer = self.tokenizer
def init_context(self): def init_context(self):
kw = {'tokenizer': self.tokenizer, kw = {'tokenizer': self.tokenizer,
'device': self.model.device, 'device': self.model.device,
'prefix_token_ids': [self.tokenizer.sot_prev]} 'prefix_token_ids': [self.tokenizer.sot_prev]}
self.context = TokenBuffer.empty(**kw) self.state.context = TokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None: if self.cfg.static_init_prompt is not None:
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw) self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None: if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt self.state.context.text += self.cfg.init_prompt
def init_tokens(self): def init_tokens(self):
logger.debug(f"init tokens, {len(self.segments)}") logger.debug(f"init tokens, {len(self.state.segments)}")
# init tokens (mandatory prompt) # init tokens (mandatory prompt)
self.initial_tokens = torch.tensor( self.state.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps, self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long, dtype=torch.long,
device=self.model.device).unsqueeze(0) device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1] self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# self.segments = [] logger.debug(f"init tokens after, {len(self.state.segments)}")
logger.debug(f"init tokens after, {len(self.segments)}") self.state.tokens = [self.state.initial_tokens]
self.tokens = [self.initial_tokens]
def trim_context(self): def trim_context(self):
logger.info("Trimming context") logger.info("Trimming context")
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids) c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}") logger.info(f"Context text: {self.state.context.as_text()}")
logger.info(f"Context text: {self.context.as_text()}") l = sum(t.shape[1] for t in self.state.tokens) + c
# logger.debug(f"Context tensor: {self.context.as_tensor()}")
l = sum(t.shape[1] for t in self.tokens) + c
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if self.cfg.static_init_prompt is None: if self.cfg.static_init_prompt is None:
after = 0 after = 0
else: else:
after = len(self.cfg.static_init_prompt) after = len(self.cfg.static_init_prompt)
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
while c > self.max_context_tokens or l > self.max_text_len - 20: while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.context.trim_words(after=after) t = self.state.context.trim_words(after=after)
l -= t l -= t
c -= t c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0: if t == 0:
break break
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
logger.info(f"Context after trim: {self.context.text} (len: {l})")
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor: def logits(
if self.cfg.decoder_type == "greedy": self,
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) tokens: torch.Tensor,
audio_features: torch.Tensor,
return_cross_attn: bool = False
):
"""Get logits from decoder, optionally returning cross-attention weights."""
if self.state.decoder_type == "greedy":
return self.model.decoder(
tokens, audio_features,
kv_cache=self.state.kv_cache,
return_cross_attn=return_cross_attn
)
else: else:
logger.debug(f"Logits shape: {tokens.shape}") logger.debug(f"Logits shape: {tokens.shape}")
logit = self.inference.logits(tokens, audio_features) return self.state.inference.logits(
return logit tokens, audio_features,
return_cross_attn=return_cross_attn
)
def refresh_segment(self, complete=False): def refresh_segment(self, complete=False):
logger.debug("Refreshing segment:") logger.debug("Refreshing segment:")
self.init_tokens() self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold self.state.last_attend_frame = -self.cfg.rewind_threshold
# self.detected_language = None self.state.cumulative_time_offset = 0.0
self.cumulative_time_offset = 0.0
self.init_context() self.init_context()
logger.debug(f"Context: {self.context}") logger.debug(f"Context: {self.state.context}")
if not complete and len(self.segments) > 2: if not complete and len(self.state.segments) > 2:
self.segments = self.segments[-2:] self.state.segments = self.state.segments[-2:]
else: else:
logger.debug("removing all segments.") logger.debug("removing all segments.")
self.segments = [] self.state.segments = []
self.log_segments += 1 self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
self.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
if self.always_fire: return True if self.state.always_fire:
if self.never_fire: return False return True
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear) if self.state.never_fire:
return False
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
def _current_tokens(self): def _current_tokens(self):
toks = self.state.tokens
toks = self.tokens
# very first infer: duplicate start of seq to beam_size # very first infer: duplicate start of seq to beam_size
if toks[0].shape[0] == 1: if toks[0].shape[0] == 1:
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0) toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
if not self.context.is_empty(): if not self.state.context.is_empty():
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device) context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
toks = [context_toks] + toks toks = [context_toks] + toks
# make it one tensor # make it one tensor
@@ -313,7 +313,7 @@ class AlignAtt:
### audio buffer ### audio buffer
def segments_len(self): def segments_len(self):
segments_len = sum(s.shape[0] for s in self.segments) / 16000 segments_len = sum(s.shape[0] for s in self.state.segments) / 16000
return segments_len return segments_len
def _apply_minseglen(self): def _apply_minseglen(self):
@@ -326,42 +326,36 @@ class AlignAtt:
def insert_audio(self, segment=None): def insert_audio(self, segment=None):
if segment is not None: if segment is not None:
self.segments.append(segment) self.state.segments.append(segment)
removed_len = 0 removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment # len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len() segments_len = self.segments_len()
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len: while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000 removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len) self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.cumulative_time_offset += removed_len # Track cumulative time removed self.state.cumulative_time_offset += removed_len # Track cumulative time removed
self.segments = self.segments[1:] self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s") logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.tokens) > 1: if len(self.state.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:].tolist()) self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
self.tokens = [self.initial_tokens] + self.tokens[2:] self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len return removed_len
def _clean_cache(self): def _clean_cache(self):
'''clean the cache that stores the attention matrices and kv_cache. """Clean the kv_cache after each inference step."""
It must be called every time after generation with the model.''' self.state.clean_cache()
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
@torch.no_grad() @torch.no_grad()
def lang_id(self, encoder_features): def lang_id(self, encoder_features):
"""Language detection from encoder features. """Language detection from encoder features.
This code is trimmed and copy-pasted from whisper.decoding.detect_language . This code is trimmed and copy-pasted from whisper.decoding.detect_language.
""" """
# forward pass using a single token, startoftranscript # forward pass using a single token, startoftranscript
n_audio = encoder_features.shape[0] n_audio = encoder_features.shape[0]
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1] x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
# Note: don't use kv_cache for language detection
logits = self.model.logits(x, encoder_features)[:, 0] logits = self.model.logits(x, encoder_features)[:, 0]
# collect detected languages; suppress all non-language tokens # collect detected languages; suppress all non-language tokens
@@ -391,19 +385,19 @@ class AlignAtt:
@torch.no_grad() @torch.no_grad()
def infer(self, is_last=False): def infer(self, is_last=False):
new_segment = True new_segment = True
if len(self.segments) == 0: if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do") logger.debug("No segments, nothing to do")
return [] return []
if not self._apply_minseglen(): if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0) input_segments = torch.cat(self.state.segments, dim=0)
return [] return []
# input_segments is concatenation of audio, it's one array # input_segments is concatenation of audio, it's one array
if len(self.segments) > 1: if len(self.state.segments) > 1:
input_segments = torch.cat(self.segments, dim=0) input_segments = torch.cat(self.state.segments, dim=0)
else: else:
input_segments = self.segments[0] input_segments = self.state.segments[0]
beg_encode = time() beg_encode = time()
if self.use_mlcore: if self.use_mlcore:
@@ -457,18 +451,18 @@ class AlignAtt:
end_encode = time() end_encode = time()
# print('Encoder duration:', end_encode-beg_encode) # print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp: if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.first_timestamp seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0: if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature) language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}") print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan) self.create_tokenizer(top_lan)
self.last_attend_frame = -self.cfg.rewind_threshold self.state.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0 self.state.cumulative_time_offset = 0.0
self.init_tokens() self.init_tokens()
self.init_context() self.init_context()
self.detected_language = top_lan self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
self.trim_context() self.trim_context()
@@ -488,92 +482,80 @@ class AlignAtt:
l_absolute_timestamps = [] l_absolute_timestamps = []
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens accumulated_cross_attns = []
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
if new_segment: if new_segment:
tokens_for_logits = current_tokens tokens_for_logits = current_tokens
else: else:
# only need to use the last token except in the first forward pass # only need to use the last token except in the first forward pass
tokens_for_logits = current_tokens[:,-1:] tokens_for_logits = current_tokens[:, -1:]
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size # Get logits and cross-attention weights from decoder
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result
# Accumulate cross-attention from this forward pass
accumulated_cross_attns.append(cross_attns)
if new_segment and self.tokenizer.no_speech is not None: if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1) probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob: if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop") logger.info("no speech, stop")
break break
logits = logits[:, -1, :] # logits for the last token logits = logits[:, -1, :] # logits for the last token
# supress blank tokens only at the beginning of the segment # suppress blank tokens only at the beginning of the segment
if new_segment: if new_segment:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
new_segment = False new_segment = False
self.suppress_tokens(logits) self.state.suppress_tokens_fn(logits)
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs) current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ") logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens) self.debug_print_tokens(current_tokens)
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)] # Process accumulated cross-attention weights for alignment
for i, attn_mat in enumerate(self.dec_attns): attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
layer_rank = int(i % len(self.model.decoder.blocks))
align_heads_in_layer = self.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0)
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
t = torch.cat(mat, dim=1)
tmp.append(t)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
# for each beam, the most attended frame is: # for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1) most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1)
# Calculate absolute timestamps accounting for cumulative offset # Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()] absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in most_attended_frames.tolist()
]
logger.debug(str(most_attended_frames.tolist()) + " most att frames") logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)") logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item() most_attended_frame = most_attended_frames[0].item()
l_absolute_timestamps.append(absolute_timestamps[0]) l_absolute_timestamps.append(absolute_timestamps[0])
logger.debug("current tokens" + str(current_tokens.shape)) logger.debug("current tokens" + str(current_tokens.shape))
if completed: if completed:
# # stripping the last token, the eot # stripping the last token, the eot
current_tokens = current_tokens[:, :-1] current_tokens = current_tokens[:, :-1]
break break
# for some rare cases where the attention fails # for some rare cases where the attention fails
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold: if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
# TODO: check this
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD: if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
logger.debug("ommit rewinding from special tokens") logger.debug("omit rewinding from special tokens")
self.last_attend_frame = most_attended_frame self.state.last_attend_frame = most_attended_frame
else: else:
logger.debug( logger.debug(
f"[rewind detected] current attention pos: {most_attended_frame}, " f"[rewind detected] current attention pos: {most_attended_frame}, "
f"last attention pos: {self.last_attend_frame}; omit this segment") f"last attention pos: {self.state.last_attend_frame}; omit this segment")
self.last_attend_frame = -self.cfg.rewind_threshold self.state.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0] current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
break break
else: else:
self.last_attend_frame = most_attended_frame self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold): if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}") logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
@@ -593,12 +575,12 @@ class AlignAtt:
tokens_to_split = current_tokens[0, token_len_before_decoding:] tokens_to_split = current_tokens[0, token_len_before_decoding:]
# Prepend pending tokens from previous chunk if any # Prepend pending tokens from previous chunk if any
if self.pending_incomplete_tokens: if self.state.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}") logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}")
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device) pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device)
tokens_to_split = torch.cat([pending_tensor, tokens_to_split]) tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
if fire_detected or is_last: #or punctuation_stop: if fire_detected or is_last:
new_hypothesis = tokens_to_split.flatten().tolist() new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else: else:
@@ -609,20 +591,18 @@ class AlignAtt:
else: else:
new_hypothesis = [] new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}") logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to( new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.device, device=self.device,
) )
self.tokens.append(new_tokens) self.state.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache() self._clean_cache()
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None: if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.first_timestamp = l_absolute_timestamps[0] self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = [] timestamped_words = []
timestamp_idx = 0 timestamp_idx = 0
@@ -641,20 +621,85 @@ class AlignAtt:
timestamp_idx += len(word_tokens) timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken( timestamp_entry = ASRToken(
start=round(current_timestamp, 2), start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2), end=round(current_timestamp + 0.1, 2),
text= word, text=word,
speaker=self.speaker, speaker=self.state.speaker,
detected_language=self.detected_language detected_language=self.state.detected_language
).with_offset( ).with_offset(
self.global_time_offset self.state.global_time_offset
) )
timestamped_words.append(timestamp_entry) timestamped_words.append(timestamp_entry)
# Hold incomplete tokens for next chunk # Hold incomplete tokens for next chunk
self.pending_incomplete_tokens = [] self.state.pending_incomplete_tokens = []
if split_words and replacement_char in split_words[-1]: if split_words and replacement_char in split_words[-1]:
self.pending_incomplete_tokens = split_tokens[-1] self.state.pending_incomplete_tokens = split_tokens[-1]
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}") logger.warning(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.state.pending_incomplete_tokens}")
return timestamped_words return timestamped_words
def _process_cross_attention(
self,
cross_attns: List[torch.Tensor],
content_mel_len: int
) -> torch.Tensor:
"""
Process cross-attention weights from decoder layers for alignment.
Args:
cross_attns: List of cross-attention tensors from each decoder layer.
Each tensor has shape (batch, n_head, seq_len, audio_len)
content_mel_len: Length of actual audio content in mel frames
Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len)
"""
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = len(self.model.decoder.blocks)
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
layer_rank = idx % num_decoder_layers
# attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
attn_mat = F.softmax(attn_mat, dim=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
# (n_head, seq_len, audio_len) when squeezed
if attn_mat.dim() == 4:
a = attn_mat[0, head_id, :, :] # (seq_len, audio_len)
else:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0) # (1, seq_len, audio_len)
else:
# attn_mat: (batch, n_head, seq_len, audio_len)
a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len)
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len)
tmp.append(t)
if not tmp:
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
# stack all heads: (batch, num_align_heads, seq_len, audio_len)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
return attn_of_alignment_heads

View File

@@ -1,5 +1,8 @@
import torch
import sys import sys
import torch
class TokenBuffer: class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]): def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, List, Union, Dict, Any
from datetime import timedelta from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''} PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}

View File

@@ -1,7 +1,9 @@
from time import time from time import time
from typing import Optional, List, Tuple, Union, Any from typing import Any, List, Optional, Tuple, Union
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence,
SilentLine, SpeakerSegment,
TimedText)
class TokensAlignment: class TokensAlignment:

View File

@@ -7,6 +7,7 @@ def load_file(warmup_file=None, timeout=5):
import os import os
import tempfile import tempfile
import urllib.request import urllib.request
import librosa import librosa
if warmup_file == "": if warmup_file == "":

View File

@@ -1,6 +1,6 @@
import logging
import importlib.resources as resources
import base64 import base64
import importlib.resources as resources
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -96,11 +96,13 @@ def get_inline_ui_html():
if __name__ == '__main__': if __name__ == '__main__':
import pathlib
import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
import uvicorn
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg import whisperlivekit.web as webpkg
app = FastAPI() app = FastAPI()

View File

@@ -4,15 +4,17 @@ import json
import os import os
import urllib import urllib
import warnings import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import torch import torch
from tqdm import tqdm
from pathlib import Path
from torch import Tensor from torch import Tensor
from tqdm import tqdm
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language pad_or_trim)
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
decode, detect_language)
from whisperlivekit.whisper.model import ModelDimensions, Whisper from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__ from whisperlivekit.whisper.version import __version__

View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
Tuple, Union)
import numpy as np import numpy as np
import torch import torch
@@ -146,16 +147,13 @@ class PyTorchInference(Inference):
self.model: "Whisper" = model self.model: "Whisper" = model
self.initial_token_length = initial_token_length self.initial_token_length = initial_token_length
self.kv_cache = {} self.kv_cache = {}
self.hooks = []
key_modules = [block.attn.key for block in self.model.decoder.blocks] self.kv_cache_ids = []
value_modules = [block.attn.value for block in self.model.decoder.blocks] for block in self.model.decoder.blocks:
self.kv_modules = key_modules + value_modules self.kv_cache_ids.append(block.attn.key_cache_id)
self.kv_cache_ids.append(block.attn.value_cache_id)
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
if tokens.shape[-1] > self.initial_token_length: if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass # only need to use the last token except in the first forward pass
tokens = tokens[:, -1:] tokens = tokens[:, -1:]
@@ -163,17 +161,14 @@ class PyTorchInference(Inference):
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
def cleanup_caching(self): def cleanup_caching(self):
for hook in self.hooks:
hook.remove()
self.kv_cache = {} self.kv_cache = {}
self.hooks = []
def rearrange_kv_cache(self, source_indices): def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))): if source_indices != list(range(len(source_indices))):
for module in self.kv_modules: for cache_id in self.kv_cache_ids:
# update the key/value cache to contain the selected sequences if cache_id in self.kv_cache:
self.kv_cache[module] = self.kv_cache[module][source_indices].detach() # update the key/value cache to contain the selected sequences
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
class SequenceRanker: class SequenceRanker:

View File

@@ -79,18 +79,23 @@ def disable_sdpa():
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks use_sdpa = False # Disable SDPA to ensure qk is always computed when needed
def __init__(self, n_state: int, n_head: int, cache_id: str = ""): def __init__(self, n_state: int, n_head: int, cache_id: str = "", n_text_ctx: int = 448):
super().__init__() super().__init__()
self.n_head = n_head self.n_head = n_head
self.n_text_ctx = n_text_ctx
self.query = Linear(n_state, n_state) self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False) self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state) self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state) self.out = Linear(n_state, n_state)
self.cache_id = cache_id self.cache_id = cache_id
self.key.cache_id = f"{cache_id}_key" # Cache IDs for key and value (used with dict-based kv_cache)
self.value.cache_id = f"{cache_id}_value" self.key_cache_id = f"{cache_id}_key"
self.value_cache_id = f"{cache_id}_value"
# Keep these for backward compatibility with hook-based caching
self.key.cache_id = self.key_cache_id
self.value.cache_id = self.value_cache_id
def forward( def forward(
self, self,
@@ -101,19 +106,45 @@ class MultiHeadAttention(nn.Module):
): ):
q = self.query(x) q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache: if xa is None:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; # Self-attention
# otherwise, perform key/value projections for self- or cross-attention as usual. k = self.key(x)
k = self.key(x if xa is None else xa) v = self.value(x)
v = self.value(x if xa is None else xa) if kv_cache is not None:
k, v = self._update_self_attn_cache(k, v, kv_cache)
else: else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls. # Cross-attention: compute once and cache, or reuse from cache
k = kv_cache[self.key] if kv_cache is not None and self.key_cache_id in kv_cache:
v = kv_cache[self.value] k = kv_cache[self.key_cache_id]
v = kv_cache[self.value_cache_id]
else:
k = self.key(xa)
v = self.value(xa)
if kv_cache is not None:
kv_cache[self.key_cache_id] = k
kv_cache[self.value_cache_id] = v
wv, qk = self.qkv_attention(q, k, v, mask) wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk return self.out(wv), qk
def _update_self_attn_cache(
self, k: Tensor, v: Tensor, kv_cache: dict
) -> Tuple[Tensor, Tensor]:
"""Update self-attention kv cache by concatenating new k,v with cached values."""
if self.key_cache_id not in kv_cache or k.shape[1] > self.n_text_ctx:
# First token or context overflow: save as-is
kv_cache[self.key_cache_id] = k.detach()
kv_cache[self.value_cache_id] = v.detach()
else:
# Concatenate with existing cache
cached_k = kv_cache[self.key_cache_id]
cached_v = kv_cache[self.value_cache_id]
k = torch.cat([cached_k, k], dim=1).detach()
v = torch.cat([cached_v, v], dim=1).detach()
kv_cache[self.key_cache_id] = k
kv_cache[self.value_cache_id] = v
return k, v
def qkv_attention( def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -143,14 +174,21 @@ class MultiHeadAttention(nn.Module):
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""): def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False,
cache_id: str = "", n_text_ctx: int = 448
):
super().__init__() super().__init__()
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn") self.attn = MultiHeadAttention(
n_state, n_head, cache_id=f"{cache_id}_self_attn", n_text_ctx=n_text_ctx
)
self.attn_ln = LayerNorm(n_state) self.attn_ln = LayerNorm(n_state)
self.cross_attn = ( self.cross_attn = (
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None MultiHeadAttention(
n_state, n_head, cache_id=f"{cache_id}_cross_attn", n_text_ctx=n_text_ctx
) if cross_attention else None
) )
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
@@ -166,12 +204,21 @@ class ResidualAttentionBlock(nn.Module):
xa: Optional[Tensor] = None, xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None, kv_cache: Optional[dict] = None,
): ) -> Tuple[Tensor, Optional[Tensor]]:
"""
Returns:
x: The output tensor
cross_attn_qk: Cross-attention weights (if cross_attn exists), else None
"""
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
cross_attn_qk = None
if self.cross_attn: if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] cross_out, cross_attn_qk = self.cross_attn(
self.cross_attn_ln(x), xa, kv_cache=kv_cache
)
x = x + cross_out
x = x + self.mlp(self.mlp_ln(x)) x = x + self.mlp(self.mlp_ln(x))
return x return x, cross_attn_qk
class AudioEncoder(nn.Module): class AudioEncoder(nn.Module):
@@ -201,7 +248,7 @@ class AudioEncoder(nn.Module):
x = (x + self.positional_embedding).to(x.dtype) x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks: for block in self.blocks:
x = block(x) x, _ = block(x) # Encoder blocks don't have cross-attention
x = self.ln_post(x) x = self.ln_post(x)
return x return x
@@ -212,13 +259,17 @@ class TextDecoder(nn.Module):
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
): ):
super().__init__() super().__init__()
self.n_ctx = n_ctx
self.token_embedding = nn.Embedding(n_vocab, n_state) self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ [
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}") ResidualAttentionBlock(
n_state, n_head, cross_attention=True,
cache_id=f"dec_layer{i}", n_text_ctx=n_ctx
)
for i in range(n_layer) for i in range(n_layer)
] ]
) )
@@ -227,28 +278,57 @@ class TextDecoder(nn.Module):
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False) self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): def forward(
self,
x: Tensor,
xa: Tensor,
kv_cache: Optional[dict] = None,
return_cross_attn: bool = False,
):
""" """
x : torch.LongTensor, shape = (batch_size, <= n_ctx) x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on the encoded audio features to be attended on
kv_cache : Optional[dict]
Dictionary to store/retrieve key-value cache for efficient decoding
return_cross_attn : bool
If True, return cross-attention weights from all decoder layers
Returns
-------
logits : Tensor
The output logits
cross_attns : Optional[List[Tensor]]
List of cross-attention weights per layer (only if return_cross_attn=True)
""" """
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 # Calculate offset from self-attention cache (not cross-attention which has audio length)
offset = 0
if kv_cache:
# Use the first decoder block's self-attention key cache to get token position
first_self_attn_key = self.blocks[0].attn.key_cache_id
if first_self_attn_key in kv_cache:
offset = kv_cache[first_self_attn_key].shape[1]
x = ( x = (
self.token_embedding(x) self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]] + self.positional_embedding[offset : offset + x.shape[-1]]
) )
x = x.to(xa.dtype) x = x.to(xa.dtype)
cross_attns = [] if return_cross_attn else None
for block in self.blocks: for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache) x, cross_attn_qk = block(x, xa, mask=self.mask, kv_cache=kv_cache)
if return_cross_attn and cross_attn_qk is not None:
cross_attns.append(cross_attn_qk)
x = self.ln(x) x = self.ln(x)
logits = ( logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float() ).float()
if return_cross_attn:
return logits, cross_attns
return logits return logits
@@ -292,8 +372,18 @@ class Whisper(nn.Module):
def embed_audio(self, mel: torch.Tensor): def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel) return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): def logits(
return self.decoder(tokens, audio_features) self,
tokens: torch.Tensor,
audio_features: torch.Tensor,
kv_cache: Optional[dict] = None,
return_cross_attn: bool = False,
):
return self.decoder(
tokens, audio_features,
kv_cache=kv_cache,
return_cross_attn=return_cross_attn
)
def forward( def forward(
self, mel: torch.Tensor, tokens: torch.Tensor self, mel: torch.Tensor, tokens: torch.Tensor
@@ -312,39 +402,6 @@ class Whisper(nn.Module):
def num_languages(self): def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual) return self.dims.n_vocab - 51765 - int(self.is_multilingual)
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
detect_language = detect_language_function detect_language = detect_language_function
transcribe = transcribe_function transcribe = transcribe_function
decode = decode_function decode = decode_function

View File

@@ -8,28 +8,13 @@ import numpy as np
import torch import torch
import tqdm import tqdm
from .audio import ( from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES,
FRAMES_PER_SECOND, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import ( from .utils import (exact_div, format_timestamp, get_end, get_writer,
exact_div, make_safe, optional_float, optional_int, str2bool)
format_timestamp,
get_end,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Whisper from .model import Whisper