mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60c62f8f84 | ||
|
|
7faa21f95f | ||
|
|
4e9f951551 | ||
|
|
870141298c | ||
|
|
872faa422a | ||
|
|
fc9cb66813 | ||
|
|
a175d1a327 |
19
README.md
19
README.md
@@ -51,9 +51,11 @@ pip install whisperlivekit
|
|||||||
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||||
|
|
||||||
|
|
||||||
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
> - See [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
> - Check the [troubleshooting guide](docs/troubleshooting.md) for step-by-step fixes collected from recent GPU setup/env issues.
|
||||||
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
||||||
|
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||||
|
|
||||||
|
|
||||||
#### Use it to capture audio from web pages.
|
#### Use it to capture audio from web pages.
|
||||||
|
|
||||||
@@ -96,11 +98,13 @@ wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
|||||||
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
|
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import asyncio
|
from whisperlivekit import AudioProcessor, TranscriptionEngine, parse_args
|
||||||
|
|
||||||
transcription_engine = None
|
transcription_engine = None
|
||||||
|
|
||||||
@@ -146,8 +150,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
| `--diarization` | Enable speaker identification | `False` |
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
|
||||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
|
||||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||||
| `--host` | Server host address | `localhost` |
|
| `--host` | Server host address | `localhost` |
|
||||||
| `--port` | Server port | `8000` |
|
| `--port` | Server port | `8000` |
|
||||||
@@ -183,7 +187,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
258
ReadmeJP.md
258
ReadmeJP.md
@@ -1,258 +0,0 @@
|
|||||||
<h1 align="center">WhisperLiveKit</h1>
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
|
||||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
|
||||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
|
|
||||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
|
|
||||||
|
|
||||||
#### 主要な研究による技術:
|
|
||||||
|
|
||||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
|
|
||||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
|
|
||||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
|
|
||||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
|
|
||||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
|
|
||||||
|
|
||||||
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか?** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
|
|
||||||
|
|
||||||
### アーキテクチャ
|
|
||||||
|
|
||||||
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
|
|
||||||
|
|
||||||
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
|
|
||||||
|
|
||||||
### インストールとクイックスタート
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install whisperlivekit
|
|
||||||
```
|
|
||||||
|
|
||||||
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
|
|
||||||
>
|
|
||||||
> | OS | インストール方法 |
|
|
||||||
> |-----------|-------------|
|
|
||||||
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
|
||||||
> | MacOS | `brew install ffmpeg` |
|
|
||||||
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
|
|
||||||
|
|
||||||
#### クイックスタート
|
|
||||||
1. **文字起こしサーバーを起動します:**
|
|
||||||
```bash
|
|
||||||
whisperlivekit-server --model base --language en
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
|
|
||||||
|
|
||||||
|
|
||||||
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
|
|
||||||
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
|
|
||||||
|
|
||||||
#### オプションの依存関係
|
|
||||||
|
|
||||||
| オプション | `pip install` |
|
|
||||||
|-----------|-------------|
|
|
||||||
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
|
||||||
| Diartによる話者ダイアライゼーション | `diart` |
|
|
||||||
| オリジナルのWhisperバックエンド | `whisper` |
|
|
||||||
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
|
|
||||||
| Apple Silicon最適化バックエンド | `mlx-whisper` |
|
|
||||||
| OpenAI APIバックエンド | `openai` |
|
|
||||||
|
|
||||||
それらの使用方法については、以下の**パラメータと設定**を参照してください。
|
|
||||||
|
|
||||||
### 使用例
|
|
||||||
|
|
||||||
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# デフォルト(small)より良いモデルを使用
|
|
||||||
whisperlivekit-server --model large-v3
|
|
||||||
|
|
||||||
# ダイアライゼーションと言語を指定した高度な設定
|
|
||||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
|
||||||
```
|
|
||||||
|
|
||||||
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
|
|
||||||
|
|
||||||
```python
|
|
||||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
transcription_engine = None
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
global transcription_engine
|
|
||||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
|
||||||
yield
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
|
||||||
|
|
||||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
|
||||||
async for response in results_generator:
|
|
||||||
await websocket.send_json(response)
|
|
||||||
await websocket.send_json({"type": "ready_to_stop"})
|
|
||||||
|
|
||||||
@app.websocket("/asr")
|
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
|
||||||
global transcription_engine
|
|
||||||
|
|
||||||
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
|
|
||||||
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
|
|
||||||
results_generator = await audio_processor.create_tasks()
|
|
||||||
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
|
||||||
await websocket.accept()
|
|
||||||
while True:
|
|
||||||
message = await websocket.receive_bytes()
|
|
||||||
await audio_processor.process_audio(message)
|
|
||||||
```
|
|
||||||
|
|
||||||
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
|
|
||||||
|
|
||||||
|
|
||||||
## パラメータと設定
|
|
||||||
|
|
||||||
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
|
|
||||||
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
|
||||||
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
|
|
||||||
- `--backend`? `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
|
|
||||||
- `--warmup-file`、もしあれば
|
|
||||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
|
|
||||||
- `--diarization`、使用したい場合。
|
|
||||||
|
|
||||||
残りは推奨しません。しかし、以下があなたのオプションです。
|
|
||||||
|
|
||||||
| パラメータ | 説明 | デフォルト |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--model` | Whisperモデルのサイズ。 | `small` |
|
|
||||||
| `--language` | ソース言語コードまたは`auto` | `auto` |
|
|
||||||
| `--task` | `transcribe`または`translate` | `transcribe` |
|
|
||||||
| `--backend` | 処理バックエンド | `simulstreaming` |
|
|
||||||
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
|
|
||||||
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
|
|
||||||
| `--no-vad` | 音声区間検出を無効化 | `False` |
|
|
||||||
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
|
|
||||||
| `--host` | サーバーホストアドレス | `localhost` |
|
|
||||||
| `--port` | サーバーポート | `8000` |
|
|
||||||
| `--ssl-certfile` | SSL証明書ファイルへのパス(HTTPSサポート用) | `None` |
|
|
||||||
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパス(HTTPSサポート用) | `None` |
|
|
||||||
|
|
||||||
|
|
||||||
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
|
|
||||||
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment`) | `segment` |
|
|
||||||
|
|
||||||
|
|
||||||
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--frame-threshold` | AlignAttフレームしきい値(低いほど速く、高いほど正確) | `25` |
|
|
||||||
| `--beams` | ビームサーチのビーム数(1 = 貪欲デコーディング) | `1` |
|
|
||||||
| `--decoder` | デコーダタイプを強制(`beam`または`greedy`) | `auto` |
|
|
||||||
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
|
|
||||||
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
|
|
||||||
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
|
|
||||||
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
|
|
||||||
| `--init-prompt` | モデルの初期プロンプト | `None` |
|
|
||||||
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
|
|
||||||
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
|
|
||||||
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
|
|
||||||
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
|
|
||||||
|
|
||||||
| ダイアライゼーションオプション | 説明 | デフォルト |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--diarization` | 話者識別を有効化 | `False` |
|
|
||||||
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
|
|
||||||
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
|
||||||
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
|
||||||
|
|
||||||
|
|
||||||
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です:
|
|
||||||
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
|
|
||||||
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
|
|
||||||
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
|
|
||||||
>4. HuggingFaceでログイン: `huggingface-cli login`
|
|
||||||
|
|
||||||
### 🚀 デプロイガイド
|
|
||||||
|
|
||||||
WhisperLiveKitを本番環境にデプロイするには:
|
|
||||||
|
|
||||||
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
|
|
||||||
```bash
|
|
||||||
pip install uvicorn gunicorn
|
|
||||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
|
|
||||||
|
|
||||||
3. **Nginx設定** (本番環境で推奨):
|
|
||||||
```nginx
|
|
||||||
server {
|
|
||||||
listen 80;
|
|
||||||
server_name your-domain.com;
|
|
||||||
location / {
|
|
||||||
proxy_pass http://localhost:8000;
|
|
||||||
proxy_set_header Upgrade $http_upgrade;
|
|
||||||
proxy_set_header Connection "upgrade";
|
|
||||||
proxy_set_header Host $host;
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
|
|
||||||
|
|
||||||
## 🐋 Docker
|
|
||||||
|
|
||||||
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
|
|
||||||
|
|
||||||
### 前提条件
|
|
||||||
- Dockerがシステムにインストールされていること
|
|
||||||
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
|
|
||||||
|
|
||||||
### クイックスタート
|
|
||||||
|
|
||||||
**GPUアクセラレーション付き (推奨):**
|
|
||||||
```bash
|
|
||||||
docker build -t wlk .
|
|
||||||
docker run --gpus all -p 8000:8000 --name wlk wlk
|
|
||||||
```
|
|
||||||
|
|
||||||
**CPUのみ:**
|
|
||||||
```bash
|
|
||||||
docker build -f Dockerfile.cpu -t wlk .
|
|
||||||
docker run -p 8000:8000 --name wlk wlk
|
|
||||||
```
|
|
||||||
|
|
||||||
### 高度な使用法
|
|
||||||
|
|
||||||
**カスタム設定:**
|
|
||||||
```bash
|
|
||||||
# カスタムモデルと言語の例
|
|
||||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
|
||||||
```
|
|
||||||
|
|
||||||
### メモリ要件
|
|
||||||
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
|
|
||||||
|
|
||||||
|
|
||||||
#### カスタマイズ
|
|
||||||
|
|
||||||
- `--build-arg` オプション:
|
|
||||||
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
|
|
||||||
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
|
|
||||||
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
|
|
||||||
|
|
||||||
## 🔮 ユースケース
|
|
||||||
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...
|
|
||||||
113
docs/troubleshooting.md
Normal file
113
docs/troubleshooting.md
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# Troubleshooting
|
||||||
|
|
||||||
|
|
||||||
|
## GPU drivers & cuDNN visibility
|
||||||
|
|
||||||
|
### Linux error: `Unable to load libcudnn_ops.so* / cudnnCreateTensorDescriptor`
|
||||||
|
> Reported in issue #271 (Arch/CachyOS)
|
||||||
|
|
||||||
|
`faster-whisper` (used for the SimulStreaming encoder) dynamically loads cuDNN.
|
||||||
|
If the runtime cannot find `libcudnn_*`, verify that CUDA and cuDNN match the PyTorch build you installed:
|
||||||
|
|
||||||
|
1. **Install CUDA + cuDNN** (Arch/CachyOS example):
|
||||||
|
```bash
|
||||||
|
sudo pacman -S cuda cudnn
|
||||||
|
sudo ldconfig
|
||||||
|
```
|
||||||
|
2. **Make sure the shared objects are visible**:
|
||||||
|
```bash
|
||||||
|
ls /usr/lib/libcudnn*
|
||||||
|
```
|
||||||
|
3. **Check what CUDA version PyTorch expects** and match that with the driver you installed:
|
||||||
|
```bash
|
||||||
|
python - <<'EOF'
|
||||||
|
import torch
|
||||||
|
print(torch.version.cuda)
|
||||||
|
EOF
|
||||||
|
nvcc --version
|
||||||
|
```
|
||||||
|
4. If you installed CUDA in a non-default location, export `CUDA_HOME` and add `$CUDA_HOME/lib64` to `LD_LIBRARY_PATH`.
|
||||||
|
|
||||||
|
Once the CUDA/cuDNN versions match, `whisperlivekit-server` starts normally.
|
||||||
|
|
||||||
|
### Windows error: `Could not locate cudnn_ops64_9.dll`
|
||||||
|
> Reported in issue #286 (Conda on Windows)
|
||||||
|
|
||||||
|
PyTorch bundles cuDNN DLLs inside your environment (`<env>\Lib\site-packages\torch\lib`).
|
||||||
|
When `ctranslate2` or `faster-whisper` cannot find `cudnn_ops64_9.dll`:
|
||||||
|
|
||||||
|
1. Locate the DLL shipped with PyTorch, e.g.
|
||||||
|
```
|
||||||
|
E:\conda\envs\WhisperLiveKit\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
|
||||||
|
```
|
||||||
|
2. Add that directory to your `PATH` **or** copy the `cudnn_*64_9.dll` files into a directory that is already on `PATH` (such as the environment's `Scripts/` folder).
|
||||||
|
3. Restart the shell before launching `wlk`.
|
||||||
|
|
||||||
|
Installing NVIDIA's standalone cuDNN 9.x and pointing `PATH`/`CUDNN_PATH` to it works as well, but is usually not required.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## PyTorch / CTranslate2 GPU builds
|
||||||
|
|
||||||
|
### `Torch not compiled with CUDA enabled`
|
||||||
|
> Reported in issue #284
|
||||||
|
|
||||||
|
If `torch.zeros(1).cuda()` raises that assertion it means you installed a CPU-only wheel.
|
||||||
|
Install the GPU-enabled wheels that match your CUDA toolkit:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace `cu130` with the CUDA version supported by your driver (see [PyTorch install selector](https://pytorch.org/get-started/locally/)).
|
||||||
|
Validate with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
print(torch.cuda.is_available(), torch.cuda.get_device_name())
|
||||||
|
```
|
||||||
|
|
||||||
|
### `CTranslate2 device count: 0` or `Could not infer dtype of ctranslate2._ext.StorageView`
|
||||||
|
> Follow-up in issue #284
|
||||||
|
|
||||||
|
`ctranslate2` publishes separate CPU and CUDA wheels. The default `pip install ctranslate2` brings the CPU build, which makes WhisperLiveKit fall back to CPU tensors and leads to the dtype error above.
|
||||||
|
|
||||||
|
1. Uninstall the CPU build: `pip uninstall -y ctranslate2`.
|
||||||
|
2. Install the CUDA wheel that matches your toolkit (example for CUDA 13.0):
|
||||||
|
```bash
|
||||||
|
pip install ctranslate2==4.5.0 -f https://opennmt.net/ctranslate2/whl/cu130
|
||||||
|
```
|
||||||
|
(See the [CTranslate2 installation table](https://opennmt.net/CTranslate2/installation.html) for other CUDA versions.)
|
||||||
|
3. Verify:
|
||||||
|
```python
|
||||||
|
import ctranslate2
|
||||||
|
print("CUDA devices:", ctranslate2.get_cuda_device_count())
|
||||||
|
```
|
||||||
|
|
||||||
|
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Hopper / Blackwell (`sm_121a`) systems
|
||||||
|
> Reported in issue #276 (NVIDIA DGX Spark)
|
||||||
|
|
||||||
|
CUDA 12.1a GPUs ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual hints:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_HOME="/usr/local/cuda-13.0"
|
||||||
|
export PATH="$CUDA_HOME/bin:$PATH"
|
||||||
|
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
|
||||||
|
|
||||||
|
# Tell Triton where the new ptxas lives
|
||||||
|
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
|
||||||
|
|
||||||
|
# Force PyTorch to JIT kernels for all needed architectures
|
||||||
|
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
|
||||||
|
```
|
||||||
|
|
||||||
|
After exporting those variables (or adding them to your systemd service / shell profile), restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Need help with another recurring issue? Open a GitHub discussion or PR and reference this document so we can keep it current.
|
||||||
|
|
||||||
@@ -61,10 +61,10 @@ packages = [
|
|||||||
"whisperlivekit.whisper.normalizers",
|
"whisperlivekit.whisper.normalizers",
|
||||||
"whisperlivekit.web",
|
"whisperlivekit.web",
|
||||||
"whisperlivekit.local_agreement",
|
"whisperlivekit.local_agreement",
|
||||||
"whisperlivekit.vad_models"
|
"whisperlivekit.silero_vad_models"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||||
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
|
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ from typing import Dict, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||||
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||||
from whisperlivekit.whisper.model import ModelDimensions
|
from whisperlivekit.whisper.model import ModelDimensions
|
||||||
from whisperlivekit.whisper.utils import exact_div
|
from whisperlivekit.whisper.utils import exact_div
|
||||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
||||||
|
|||||||
@@ -5,16 +5,18 @@ import argparse
|
|||||||
import base64
|
import base64
|
||||||
import gzip
|
import gzip
|
||||||
import io
|
import io
|
||||||
|
import math
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
import math
|
|
||||||
from typing import List, Optional, Sequence, Tuple, Union
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from datasets import Audio as DatasetAudio, load_dataset
|
|
||||||
import soundfile as sf
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from datasets import Audio as DatasetAudio
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||||
WHISPER_ROOT = REPO_ROOT / "whisper"
|
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Copy core files from web directory to Chrome extension directory."""
|
"""Copy core files from web directory to Chrome extension directory."""
|
||||||
|
|
||||||
import shutil
|
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def sync_extension_files():
|
def sync_extension_files():
|
||||||
|
|
||||||
web_dir = Path("whisperlivekit/web")
|
web_dir = Path("whisperlivekit/web")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .audio_processor import AudioProcessor
|
from .audio_processor import AudioProcessor
|
||||||
from .core import TranscriptionEngine
|
from .core import TranscriptionEngine
|
||||||
from .parse_args import parse_args
|
from .parse_args import parse_args
|
||||||
from .web.web_interface import get_web_interface_html, get_inline_ui_html
|
from .web.web_interface import get_inline_ui_html, get_web_interface_html
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TranscriptionEngine",
|
"TranscriptionEngine",
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import numpy as np
|
|
||||||
from time import time
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, Union, List, Any, AsyncGenerator
|
from time import time
|
||||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
|
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
|
||||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
import numpy as np
|
||||||
|
|
||||||
|
from whisperlivekit.core import (TranscriptionEngine,
|
||||||
|
online_diarization_factory, online_factory,
|
||||||
|
online_translation_factory)
|
||||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||||
|
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
||||||
|
Line, Silence, State, Transcript)
|
||||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@@ -603,16 +609,16 @@ class AudioProcessor:
|
|||||||
res = self.vac(pcm_array)
|
res = self.vac(pcm_array)
|
||||||
|
|
||||||
if res is not None:
|
if res is not None:
|
||||||
silence_detected = res.get("end", 0) > res.get("start", 0)
|
if "start" in res and self.current_silence:
|
||||||
if silence_detected and not self.current_silence:
|
await self._end_silence()
|
||||||
|
|
||||||
|
if "end" in res and not self.current_silence:
|
||||||
pre_silence_chunk = self._slice_before_silence(
|
pre_silence_chunk = self._slice_before_silence(
|
||||||
pcm_array, chunk_sample_start, res.get("end")
|
pcm_array, chunk_sample_start, res.get("end")
|
||||||
)
|
)
|
||||||
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
||||||
await self._enqueue_active_audio(pre_silence_chunk)
|
await self._enqueue_active_audio(pre_silence_chunk)
|
||||||
await self._begin_silence()
|
await self._begin_silence()
|
||||||
elif self.current_silence:
|
|
||||||
await self._end_silence()
|
|
||||||
|
|
||||||
if not self.current_silence:
|
if not self.current_silence:
|
||||||
await self._enqueue_active_audio(pcm_array)
|
await self._enqueue_active_audio(pcm_array)
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from contextlib import asynccontextmanager
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
|
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
|
||||||
|
get_inline_ui_html, parse_args)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logging.getLogger().setLevel(logging.WARNING)
|
logging.getLogger().setLevel(logging.WARNING)
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
|
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||||
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
|
||||||
from argparse import Namespace
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
|
|
||||||
def update_with_kwargs(_dict, kwargs):
|
def update_with_kwargs(_dict, kwargs):
|
||||||
_dict.update({
|
_dict.update({
|
||||||
@@ -80,6 +82,7 @@ class TranscriptionEngine:
|
|||||||
|
|
||||||
if self.args.vac:
|
if self.args.vac:
|
||||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||||
|
|
||||||
# Use ONNX if specified, otherwise use JIT (default)
|
# Use ONNX if specified, otherwise use JIT (default)
|
||||||
use_onnx = kwargs.get('vac_onnx', False)
|
use_onnx = kwargs.get('vac_onnx', False)
|
||||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||||
@@ -100,7 +103,6 @@ class TranscriptionEngine:
|
|||||||
"init_prompt": None,
|
"init_prompt": None,
|
||||||
"static_init_prompt": None,
|
"static_init_prompt": None,
|
||||||
"max_context_tokens": None,
|
"max_context_tokens": None,
|
||||||
"preload_model_count": 1,
|
|
||||||
}
|
}
|
||||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||||
|
|
||||||
@@ -135,7 +137,8 @@ class TranscriptionEngine:
|
|||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
if self.args.diarization_backend == "diart":
|
if self.args.diarization_backend == "diart":
|
||||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
from whisperlivekit.diarization.diart_backend import \
|
||||||
|
DiartDiarization
|
||||||
diart_params = {
|
diart_params = {
|
||||||
"segmentation_model": "pyannote/segmentation-3.0",
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
"embedding_model": "pyannote/embedding",
|
"embedding_model": "pyannote/embedding",
|
||||||
@@ -146,7 +149,8 @@ class TranscriptionEngine:
|
|||||||
**diart_params
|
**diart_params
|
||||||
)
|
)
|
||||||
elif self.args.diarization_backend == "sortformer":
|
elif self.args.diarization_backend == "sortformer":
|
||||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
from whisperlivekit.diarization.sortformer_backend import \
|
||||||
|
SortformerDiarization
|
||||||
self.diarization_model = SortformerDiarization()
|
self.diarization_model = SortformerDiarization()
|
||||||
|
|
||||||
self.translation_model = None
|
self.translation_model = None
|
||||||
@@ -182,7 +186,8 @@ def online_diarization_factory(args, diarization_backend):
|
|||||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
||||||
|
|
||||||
if args.diarization_backend == "sortformer":
|
if args.diarization_backend == "sortformer":
|
||||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
from whisperlivekit.diarization.sortformer_backend import \
|
||||||
|
SortformerDiarizationOnline
|
||||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||||
return online
|
return online
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import numpy as np
|
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
from queue import SimpleQueue, Empty
|
from queue import Empty, SimpleQueue
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
import diart.models as m
|
||||||
|
import numpy as np
|
||||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||||
from diart.inference import StreamingInference
|
from diart.inference import StreamingInference
|
||||||
from diart.sources import AudioSource
|
from diart.sources import AudioSource, MicrophoneAudioSource
|
||||||
from whisperlivekit.timed_objects import SpeakerSegment
|
|
||||||
from diart.sources import MicrophoneAudioSource
|
|
||||||
from rx.core import Observer
|
|
||||||
from typing import Tuple, Any, List
|
|
||||||
from pyannote.core import Annotation
|
from pyannote.core import Annotation
|
||||||
import diart.models as m
|
from rx.core import Observer
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import SpeakerSegment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import wave
|
import wave
|
||||||
|
from queue import Empty, SimpleQueue
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from queue import SimpleQueue, Empty
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import SpeakerSegment
|
from whisperlivekit.timed_objects import SpeakerSegment
|
||||||
|
|
||||||
@@ -295,6 +296,7 @@ def extract_number(s: str) -> int:
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Callable
|
from typing import Callable, Optional
|
||||||
import contextlib
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
import sys
|
|
||||||
import logging
|
|
||||||
import io
|
import io
|
||||||
import soundfile as sf
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import sys
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
|
||||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
class ASRBase:
|
class ASRBase:
|
||||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||||
@@ -165,8 +168,8 @@ class MLXWhisper(ASRBase):
|
|||||||
sep = ""
|
sep = ""
|
||||||
|
|
||||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||||
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
resolved_path = resolve_model_path(model_dir)
|
resolved_path = resolve_model_path(model_dir)
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Optional
|
import sys
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import librosa
|
|
||||||
from functools import lru_cache
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
|
import sys
|
||||||
|
import time
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from whisperlivekit.backend_support import (faster_backend_available,
|
||||||
|
mlx_backend_available)
|
||||||
|
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||||
from whisperlivekit.warmup import warmup_asr
|
from whisperlivekit.warmup import warmup_asr
|
||||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
|
||||||
from whisperlivekit.backend_support import (
|
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
|
||||||
mlx_backend_available,
|
|
||||||
faster_backend_available,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -295,14 +296,6 @@ def parse_args():
|
|||||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||||
)
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
|
||||||
"--preload-model-count",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
dest="preload_model_count",
|
|
||||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
|
||||||
)
|
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--nllb-backend",
|
"--nllb-backend",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||||
"""
|
"""
|
||||||
@@ -123,7 +124,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
|
|||||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
data_dir = current_dir / 'vad_models'
|
data_dir = current_dir / 'silero_vad_models'
|
||||||
|
|
||||||
if onnx:
|
if onnx:
|
||||||
if opset_version == 16:
|
if opset_version == 16:
|
||||||
@@ -138,7 +139,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
|
|||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Model file not found: {model_path}\n"
|
f"Model file not found: {model_path}\n"
|
||||||
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
|
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
@@ -276,8 +277,10 @@ class FixedVADIterator(VADIterator):
|
|||||||
elif r is not None:
|
elif r is not None:
|
||||||
if "end" in r:
|
if "end" in r:
|
||||||
ret["end"] = r["end"]
|
ret["end"] = r["end"]
|
||||||
if "start" in r and "end" in ret:
|
if "start" in r:
|
||||||
del ret["end"]
|
ret["start"] = r["start"]
|
||||||
|
if "end" in ret:
|
||||||
|
del ret["end"]
|
||||||
return ret if ret != {} else None
|
return ret if ret != {} else None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,30 @@
|
|||||||
import sys
|
import gc
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Optional
|
import os
|
||||||
import platform
|
import platform
|
||||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from whisperlivekit.backend_support import (faster_backend_available,
|
||||||
|
mlx_backend_available)
|
||||||
|
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||||
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
|
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||||
from whisperlivekit.warmup import load_file
|
from whisperlivekit.warmup import load_file
|
||||||
from whisperlivekit.whisper import load_model, tokenizer
|
from whisperlivekit.whisper import load_model, tokenizer
|
||||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||||
import os
|
|
||||||
import gc
|
|
||||||
from pathlib import Path
|
|
||||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
|
||||||
from whisperlivekit.backend_support import (
|
|
||||||
mlx_backend_available,
|
|
||||||
faster_backend_available,
|
|
||||||
)
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
|
||||||
else:
|
else:
|
||||||
mlx_model_mapping = {}
|
mlx_model_mapping = {}
|
||||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||||
@@ -50,20 +49,19 @@ class SimulStreamingOnlineProcessor:
|
|||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.committed: List[ASRToken] = []
|
self.committed: List[ASRToken] = []
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
self.last_result_tokens: List[ASRToken] = []
|
||||||
self.load_new_backend()
|
self.load_new_alignatt_instance()
|
||||||
|
|
||||||
#can be moved
|
|
||||||
if asr.tokenizer:
|
if asr.tokenizer:
|
||||||
self.model.tokenizer = asr.tokenizer
|
self.model.tokenizer = asr.tokenizer
|
||||||
|
|
||||||
def load_new_backend(self):
|
def load_new_alignatt_instance(self):
|
||||||
model = self.asr.get_new_model_instance()
|
"""Initialize AlignAtt decoder using the shared model."""
|
||||||
self.model = AlignAtt(
|
self.model = AlignAtt(
|
||||||
cfg=self.asr.cfg,
|
cfg=self.asr.cfg,
|
||||||
loaded_model=model,
|
loaded_model=self.asr.shared_model,
|
||||||
mlx_encoder=self.asr.mlx_encoder,
|
mlx_encoder=self.asr.mlx_encoder,
|
||||||
fw_encoder=self.asr.fw_encoder,
|
fw_encoder=self.asr.fw_encoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start_silence(self):
|
def start_silence(self):
|
||||||
tokens, processed_upto = self.process_iter(is_last=True)
|
tokens, processed_upto = self.process_iter(is_last=True)
|
||||||
@@ -71,7 +69,10 @@ class SimulStreamingOnlineProcessor:
|
|||||||
|
|
||||||
def end_silence(self, silence_duration, offset):
|
def end_silence(self, silence_duration, offset):
|
||||||
"""
|
"""
|
||||||
If silences are > MIN_DURATION_REAL_SILENCE, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
Handle silence period.
|
||||||
|
|
||||||
|
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
|
||||||
|
Otherwise, insert a small silence and shift the last_attend_frame.
|
||||||
"""
|
"""
|
||||||
self.end += silence_duration
|
self.end += silence_duration
|
||||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||||
@@ -84,21 +85,20 @@ class SimulStreamingOnlineProcessor:
|
|||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
self.model.global_time_offset = silence_duration + offset
|
self.model.global_time_offset = silence_duration + offset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
|
|
||||||
# Convert numpy array to torch tensor
|
# Convert numpy array to torch tensor
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
|
||||||
self.model.insert_audio(audio_tensor)
|
self.model.insert_audio(audio_tensor)
|
||||||
|
|
||||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||||
self.process_iter(is_last=True)
|
"""Handle speaker change event."""
|
||||||
self.model.refresh_segment(complete=True)
|
self.process_iter(is_last=True)
|
||||||
self.model.speaker = change_speaker.speaker
|
self.model.refresh_segment(complete=True)
|
||||||
self.global_time_offset = change_speaker.start
|
self.model.speaker = change_speaker.speaker
|
||||||
|
self.model.global_time_offset = change_speaker.start
|
||||||
|
|
||||||
def get_buffer(self):
|
def get_buffer(self):
|
||||||
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||||
@@ -112,15 +112,17 @@ class SimulStreamingOnlineProcessor:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
timestamped_words = self.model.infer(is_last=is_last)
|
timestamped_words = self.model.infer(is_last=is_last)
|
||||||
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
|
|
||||||
|
if not timestamped_words:
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
|
||||||
self.buffer.extend(timestamped_words)
|
self.buffer.extend(timestamped_words)
|
||||||
return [], self.end
|
return [], self.end
|
||||||
|
|
||||||
self.committed.extend(timestamped_words)
|
self.committed.extend(timestamped_words)
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
return timestamped_words, self.end
|
return timestamped_words, self.end
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"SimulStreaming processing error: {e}")
|
logger.exception(f"SimulStreaming processing error: {e}")
|
||||||
return [], self.end
|
return [], self.end
|
||||||
@@ -136,12 +138,8 @@ class SimulStreamingOnlineProcessor:
|
|||||||
logger.exception(f"SimulStreaming warmup failed: {e}")
|
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
# free the model and add a new model to stack.
|
|
||||||
# del self.model
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# self.asr.new_model_to_stack()
|
|
||||||
self.model.remove_hooks()
|
|
||||||
|
|
||||||
class SimulStreamingASR():
|
class SimulStreamingASR():
|
||||||
"""SimulStreaming backend with AlignAtt policy."""
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
@@ -226,10 +224,7 @@ class SimulStreamingASR():
|
|||||||
self.tokenizer = self.set_translate_task()
|
self.tokenizer = self.set_translate_task()
|
||||||
else:
|
else:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.mlx_encoder, self.fw_encoder = None, None
|
self.mlx_encoder, self.fw_encoder = None, None
|
||||||
if self.encoder_backend == "mlx-whisper":
|
if self.encoder_backend == "mlx-whisper":
|
||||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||||
@@ -253,8 +248,7 @@ class SimulStreamingASR():
|
|||||||
device='auto',
|
device='auto',
|
||||||
compute_type='auto',
|
compute_type='auto',
|
||||||
)
|
)
|
||||||
|
self.shared_model = self.load_model()
|
||||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||||
@@ -303,11 +297,11 @@ class SimulStreamingASR():
|
|||||||
download_root=self.model_path,
|
download_root=self.model_path,
|
||||||
decoder_only=self.fast_encoder,
|
decoder_only=self.fast_encoder,
|
||||||
custom_alignment_heads=self.custom_alignment_heads
|
custom_alignment_heads=self.custom_alignment_heads
|
||||||
)
|
)
|
||||||
warmup_audio = load_file(self.warmup_file)
|
warmup_audio = load_file(self.warmup_file)
|
||||||
if warmup_audio is not None:
|
if warmup_audio is not None:
|
||||||
warmup_audio = torch.from_numpy(warmup_audio).float()
|
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||||
if self.fast_encoder:
|
if self.fast_encoder:
|
||||||
temp_model = AlignAtt(
|
temp_model = AlignAtt(
|
||||||
cfg=self.cfg,
|
cfg=self.cfg,
|
||||||
loaded_model=whisper_model,
|
loaded_model=whisper_model,
|
||||||
@@ -315,27 +309,9 @@ class SimulStreamingASR():
|
|||||||
fw_encoder=self.fw_encoder,
|
fw_encoder=self.fw_encoder,
|
||||||
)
|
)
|
||||||
temp_model.warmup(warmup_audio)
|
temp_model.warmup(warmup_audio)
|
||||||
temp_model.remove_hooks()
|
|
||||||
else:
|
else:
|
||||||
# For standard encoder, use the original transcribe warmup
|
|
||||||
warmup_audio = load_file(self.warmup_file)
|
|
||||||
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
||||||
return whisper_model
|
return whisper_model
|
||||||
|
|
||||||
def get_new_model_instance(self):
|
|
||||||
"""
|
|
||||||
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
|
|
||||||
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
|
|
||||||
"""
|
|
||||||
if len(self.models) == 0:
|
|
||||||
self.models.append(self.load_model())
|
|
||||||
new_model = self.models.pop()
|
|
||||||
return new_model
|
|
||||||
# self.models[0]
|
|
||||||
|
|
||||||
def new_model_to_stack(self):
|
|
||||||
self.models.append(self.load_model())
|
|
||||||
|
|
||||||
|
|
||||||
def set_translate_task(self):
|
def set_translate_task(self):
|
||||||
"""Set up translation task."""
|
"""Set up translation task."""
|
||||||
|
|||||||
@@ -1,17 +1,32 @@
|
|||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from whisperlivekit.whisper.decoding import PyTorchInference
|
from whisperlivekit.whisper.decoding import PyTorchInference
|
||||||
|
|
||||||
# extention of PyTorchInference for beam search
|
|
||||||
class BeamPyTorchInference(PyTorchInference):
|
|
||||||
|
|
||||||
def _kv_modules(self):
|
class BeamPyTorchInference(PyTorchInference):
|
||||||
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
|
"""Extension of PyTorchInference for beam search with cross-attention support."""
|
||||||
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
|
|
||||||
return key_modules + value_modules
|
def _kv_cache_ids(self):
|
||||||
|
"""Get cache_id strings for self-attention key/value modules."""
|
||||||
|
key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks]
|
||||||
|
value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks]
|
||||||
|
return key_ids + value_ids
|
||||||
|
|
||||||
def rearrange_kv_cache(self, source_indices):
|
def rearrange_kv_cache(self, source_indices):
|
||||||
if source_indices != list(range(len(source_indices))):
|
if source_indices != list(range(len(source_indices))):
|
||||||
for module_cache_id in self._kv_modules():
|
for cache_id in self._kv_cache_ids():
|
||||||
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
|
if cache_id in self.kv_cache:
|
||||||
from torch import Tensor
|
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
|
||||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
|
||||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
def logits(
|
||||||
|
self,
|
||||||
|
tokens: Tensor,
|
||||||
|
audio_features: Tensor,
|
||||||
|
return_cross_attn: bool = False,
|
||||||
|
):
|
||||||
|
"""Get logits, optionally returning cross-attention weights."""
|
||||||
|
return self.model.decoder(
|
||||||
|
tokens, audio_features,
|
||||||
|
kv_cache=self.kv_cache,
|
||||||
|
return_cross_attn=return_cross_attn,
|
||||||
|
)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AlignAttConfig():
|
class AlignAttConfig():
|
||||||
eval_data_path: str = "tmp"
|
eval_data_path: str = "tmp"
|
||||||
|
|||||||
80
whisperlivekit/simul_whisper/decoder_state.py
Normal file
80
whisperlivekit/simul_whisper/decoder_state.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DecoderState:
|
||||||
|
|
||||||
|
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
|
||||||
|
|
||||||
|
tokenizer: Any = None
|
||||||
|
detected_language: Optional[str] = None
|
||||||
|
reset_tokenizer_to_auto_next_call: bool = False
|
||||||
|
|
||||||
|
tokens: List[torch.Tensor] = field(default_factory=list)
|
||||||
|
initial_tokens: Optional[torch.Tensor] = None
|
||||||
|
initial_token_length: int = 0
|
||||||
|
sot_index: int = 0
|
||||||
|
|
||||||
|
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
|
||||||
|
num_align_heads: int = 0
|
||||||
|
|
||||||
|
segments: List[torch.Tensor] = field(default_factory=list)
|
||||||
|
|
||||||
|
context: Any = None
|
||||||
|
|
||||||
|
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||||
|
|
||||||
|
global_time_offset: float = 0.0
|
||||||
|
cumulative_time_offset: float = 0.0
|
||||||
|
first_timestamp: Optional[float] = None
|
||||||
|
last_attend_frame: int = 0
|
||||||
|
|
||||||
|
speaker: int = -1
|
||||||
|
log_segments: int = 0
|
||||||
|
|
||||||
|
CIFLinear: Optional[torch.nn.Module] = None
|
||||||
|
always_fire: bool = False
|
||||||
|
never_fire: bool = False
|
||||||
|
|
||||||
|
suppress_tokens_fn: Any = None
|
||||||
|
|
||||||
|
token_decoder: Any = None
|
||||||
|
decoder_type: str = "greedy"
|
||||||
|
|
||||||
|
inference: Any = None
|
||||||
|
|
||||||
|
def clean_cache(self):
|
||||||
|
"""Clean the kv_cache after each inference step."""
|
||||||
|
self.kv_cache = {}
|
||||||
|
if self.decoder_type == "beam" and self.inference is not None:
|
||||||
|
self.inference.kv_cache = self.kv_cache
|
||||||
|
if self.token_decoder is not None:
|
||||||
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
def reset(self, rewind_threshold: int = 200):
|
||||||
|
"""
|
||||||
|
Reset transient state for a new segment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewind_threshold: Value for resetting last_attend_frame
|
||||||
|
"""
|
||||||
|
self.last_attend_frame = -rewind_threshold
|
||||||
|
self.cumulative_time_offset = 0.0
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
self.log_segments += 1
|
||||||
|
|
||||||
|
def full_reset(self, rewind_threshold: int = 200):
|
||||||
|
"""
|
||||||
|
Full reset including audio segments and tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewind_threshold: Value for resetting last_attend_frame
|
||||||
|
"""
|
||||||
|
self.reset(rewind_threshold)
|
||||||
|
self.segments = []
|
||||||
|
self.tokens = []
|
||||||
|
self.kv_cache = {}
|
||||||
|
self.first_timestamp = None
|
||||||
|
|
||||||
@@ -5,7 +5,6 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten
|
||||||
|
|
||||||
from mlx_whisper import whisper
|
from mlx_whisper import whisper
|
||||||
|
|
||||||
mlx_model_mapping = {
|
mlx_model_mapping = {
|
||||||
|
|||||||
@@ -1,33 +1,36 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
|
||||||
from .config import AlignAttConfig
|
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
|
||||||
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
|
||||||
from whisperlivekit.whisper.timing import median_filter
|
|
||||||
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens
|
|
||||||
from .beam import BeamPyTorchInference
|
|
||||||
from .eow_detection import fire_at_boundary, load_cif
|
|
||||||
import os
|
import os
|
||||||
from time import time
|
from time import time
|
||||||
from .token_buffer import TokenBuffer
|
from typing import List, Optional, Tuple
|
||||||
from whisperlivekit.backend_support import (
|
|
||||||
mlx_backend_available,
|
import numpy as np
|
||||||
faster_backend_available,
|
import torch
|
||||||
)
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from whisperlivekit.backend_support import (faster_backend_available,
|
||||||
|
mlx_backend_available)
|
||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||||
|
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
|
||||||
|
TOKENS_PER_SECOND,
|
||||||
|
log_mel_spectrogram, pad_or_trim)
|
||||||
|
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
|
||||||
|
SuppressTokens)
|
||||||
|
from whisperlivekit.whisper.timing import median_filter
|
||||||
|
|
||||||
from ..timed_objects import PUNCTUATION_MARKS
|
from ..timed_objects import PUNCTUATION_MARKS
|
||||||
|
from .beam import BeamPyTorchInference
|
||||||
|
from .config import AlignAttConfig
|
||||||
|
from .decoder_state import DecoderState
|
||||||
|
from .eow_detection import fire_at_boundary, load_cif
|
||||||
|
from .token_buffer import TokenBuffer
|
||||||
|
|
||||||
DEC_PAD = 50257
|
DEC_PAD = 50257
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if mlx_backend_available():
|
if mlx_backend_available():
|
||||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
from mlx_whisper.audio import \
|
||||||
|
log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||||
|
|
||||||
if faster_backend_available():
|
if faster_backend_available():
|
||||||
@@ -52,6 +55,30 @@ def load_coreml_encoder():
|
|||||||
|
|
||||||
|
|
||||||
class AlignAtt:
|
class AlignAtt:
|
||||||
|
"""
|
||||||
|
Alignment-based Attention decoder for SimulStreaming.
|
||||||
|
|
||||||
|
This class is now hookless - the model can be shared across multiple
|
||||||
|
sessions, with each session maintaining its own DecoderState.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Property accessors for backward compatibility
|
||||||
|
@property
|
||||||
|
def speaker(self):
|
||||||
|
return self.state.speaker
|
||||||
|
|
||||||
|
@speaker.setter
|
||||||
|
def speaker(self, value):
|
||||||
|
self.state.speaker = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_time_offset(self):
|
||||||
|
return self.state.global_time_offset
|
||||||
|
|
||||||
|
@global_time_offset.setter
|
||||||
|
def global_time_offset(self, value):
|
||||||
|
self.state.global_time_offset = value
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg: AlignAttConfig,
|
cfg: AlignAttConfig,
|
||||||
@@ -59,8 +86,7 @@ class AlignAtt:
|
|||||||
mlx_encoder=None,
|
mlx_encoder=None,
|
||||||
fw_encoder=None,
|
fw_encoder=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.log_segments = 0
|
# Shared model reference (can be shared across sessions)
|
||||||
|
|
||||||
self.model = loaded_model
|
self.model = loaded_model
|
||||||
self.mlx_encoder = mlx_encoder
|
self.mlx_encoder = mlx_encoder
|
||||||
self.fw_encoder = fw_encoder
|
self.fw_encoder = fw_encoder
|
||||||
@@ -74,119 +100,89 @@ class AlignAtt:
|
|||||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
logger.info(f"Model dimensions: {self.model.dims}")
|
logger.info(f"Model dimensions: {self.model.dims}")
|
||||||
self.speaker = -1
|
|
||||||
self.decode_options = DecodingOptions(
|
self.decode_options = DecodingOptions(
|
||||||
language = cfg.language,
|
language=cfg.language,
|
||||||
without_timestamps = True,
|
without_timestamps=True,
|
||||||
task=cfg.task
|
task=cfg.task
|
||||||
)
|
)
|
||||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
|
||||||
# self.create_tokenizer('en')
|
|
||||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
|
||||||
self.global_time_offset = 0.0
|
|
||||||
self.reset_tokenizer_to_auto_next_call = False
|
|
||||||
|
|
||||||
self.max_text_len = self.model.dims.n_text_ctx
|
self.max_text_len = self.model.dims.n_text_ctx
|
||||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.l_hooks = []
|
|
||||||
|
|
||||||
# model to detect end-of-word boundary at the end of the segment
|
|
||||||
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
|
||||||
n_audio_state=self.model.dims.n_audio_state,
|
|
||||||
device=self.model.device)
|
|
||||||
|
|
||||||
# install hooks to access encoder-decoder attention
|
|
||||||
self.dec_attns = []
|
|
||||||
def layer_hook(module, net_input, net_output):
|
|
||||||
# net_output[1]: B*num_head*token_len*audio_len
|
|
||||||
t = F.softmax(net_output[1], dim=-1)
|
|
||||||
self.dec_attns.append(t.squeeze(0))
|
|
||||||
for b in self.model.decoder.blocks:
|
|
||||||
hook = b.cross_attn.register_forward_hook(layer_hook)
|
|
||||||
self.l_hooks.append(hook)
|
|
||||||
|
|
||||||
self.kv_cache = {}
|
|
||||||
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
|
|
||||||
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
|
|
||||||
# save as-is, for the first token or cross attention
|
|
||||||
self.kv_cache[module.cache_id] = net_output
|
|
||||||
else:
|
|
||||||
x = self.kv_cache[module.cache_id]
|
|
||||||
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
|
|
||||||
return self.kv_cache[module.cache_id]
|
|
||||||
|
|
||||||
for i,b in enumerate(self.model.decoder.blocks):
|
|
||||||
hooks = [
|
|
||||||
b.attn.key.register_forward_hook(kv_hook),
|
|
||||||
b.attn.value.register_forward_hook(kv_hook),
|
|
||||||
b.cross_attn.key.register_forward_hook(kv_hook),
|
|
||||||
b.cross_attn.value.register_forward_hook(kv_hook),
|
|
||||||
]
|
|
||||||
self.l_hooks.extend(hooks)
|
|
||||||
|
|
||||||
self.align_source = {}
|
|
||||||
self.num_align_heads = 0
|
|
||||||
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
|
||||||
layer_rank = layer_rank.item()
|
|
||||||
heads = self.align_source.get(layer_rank, [])
|
|
||||||
heads.append((self.num_align_heads, head_id.item()))
|
|
||||||
self.align_source[layer_rank] = heads
|
|
||||||
self.num_align_heads += 1
|
|
||||||
|
|
||||||
|
|
||||||
# tokens to be suppressed from decoding, to prevent hallucinations
|
|
||||||
suppress_tokens = [
|
|
||||||
self.tokenizer.transcribe,
|
|
||||||
self.tokenizer.translate,
|
|
||||||
self.tokenizer.sot,
|
|
||||||
self.tokenizer.sot_prev,
|
|
||||||
self.tokenizer.sot_lm,
|
|
||||||
# self.tokenizer.eot
|
|
||||||
self.tokenizer.no_timestamps, # added by DM
|
|
||||||
] + list(self.tokenizer.all_language_tokens) # added by DM
|
|
||||||
if self.tokenizer.no_speech is not None:
|
|
||||||
suppress_tokens.append(self.tokenizer.no_speech)
|
|
||||||
suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
|
||||||
logger.debug(f"Suppress tokens: {suppress_tokens}")
|
|
||||||
sup_tokens = SuppressTokens(suppress_tokens)
|
|
||||||
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
|
||||||
# blank tokens are suppresed for new segments near the line 334
|
|
||||||
|
|
||||||
# it's going to be regenerated after lang id
|
|
||||||
self.segments = []
|
|
||||||
self.init_tokens()
|
|
||||||
|
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
|
||||||
self.cumulative_time_offset = 0.0
|
|
||||||
self.first_timestamp = None
|
|
||||||
|
|
||||||
if self.cfg.max_context_tokens is None:
|
if self.cfg.max_context_tokens is None:
|
||||||
self.max_context_tokens = self.max_text_len
|
self.max_context_tokens = self.max_text_len
|
||||||
else:
|
else:
|
||||||
self.max_context_tokens = self.cfg.max_context_tokens
|
self.max_context_tokens = self.cfg.max_context_tokens
|
||||||
|
|
||||||
|
# Initialize per-session state
|
||||||
|
self.state = DecoderState()
|
||||||
|
self._init_state(cfg)
|
||||||
|
|
||||||
|
def _init_state(self, cfg: AlignAttConfig):
|
||||||
|
"""Initialize the per-session decoder state."""
|
||||||
|
# Create tokenizer
|
||||||
|
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||||
|
self.state.tokenizer = self.tokenizer
|
||||||
|
self.state.detected_language = cfg.language if cfg.language != "auto" else None
|
||||||
|
|
||||||
|
# Timing state
|
||||||
|
self.state.global_time_offset = 0.0
|
||||||
|
self.state.last_attend_frame = -cfg.rewind_threshold
|
||||||
|
self.state.speaker = -1
|
||||||
|
|
||||||
|
# CIF helpers for end-of-word boundary detection
|
||||||
|
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
|
||||||
|
cfg,
|
||||||
|
n_audio_state=self.model.dims.n_audio_state,
|
||||||
|
device=self.model.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build alignment source mapping from model's alignment_heads
|
||||||
|
self.state.align_source = {}
|
||||||
|
self.state.num_align_heads = 0
|
||||||
|
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
||||||
|
layer_rank = layer_rank.item()
|
||||||
|
heads = self.state.align_source.get(layer_rank, [])
|
||||||
|
heads.append((self.state.num_align_heads, head_id.item()))
|
||||||
|
self.state.align_source[layer_rank] = heads
|
||||||
|
self.state.num_align_heads += 1
|
||||||
|
|
||||||
|
# Build suppress tokens function
|
||||||
|
suppress_tokens = [
|
||||||
|
self.tokenizer.transcribe,
|
||||||
|
self.tokenizer.translate,
|
||||||
|
self.tokenizer.sot,
|
||||||
|
self.tokenizer.sot_prev,
|
||||||
|
self.tokenizer.sot_lm,
|
||||||
|
self.tokenizer.no_timestamps,
|
||||||
|
] + list(self.tokenizer.all_language_tokens)
|
||||||
|
if self.tokenizer.no_speech is not None:
|
||||||
|
suppress_tokens.append(self.tokenizer.no_speech)
|
||||||
|
suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||||
|
logger.debug(f"Suppress tokens: {suppress_tokens}")
|
||||||
|
sup_tokens = SuppressTokens(suppress_tokens)
|
||||||
|
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
|
||||||
|
|
||||||
|
# Initialize tokens
|
||||||
|
self.init_tokens()
|
||||||
self.init_context()
|
self.init_context()
|
||||||
|
|
||||||
# decoder type: greedy or beam
|
# Set up decoder type
|
||||||
|
self.state.decoder_type = cfg.decoder_type
|
||||||
if cfg.decoder_type == "greedy":
|
if cfg.decoder_type == "greedy":
|
||||||
logger.info("Using greedy decoder")
|
logger.info("Using greedy decoder")
|
||||||
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
||||||
self.decoder_type = "greedy"
|
|
||||||
|
|
||||||
elif cfg.decoder_type == "beam":
|
elif cfg.decoder_type == "beam":
|
||||||
self.decoder_type = "beam"
|
logger.info("Using beam decoder")
|
||||||
self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
|
self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length)
|
||||||
self.inference.kv_cache = self.kv_cache
|
self.state.inference.kv_cache = self.state.kv_cache
|
||||||
|
self.state.token_decoder = BeamSearchDecoder(
|
||||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
inference=self.state.inference,
|
||||||
|
eot=self.tokenizer.eot,
|
||||||
# Tokens to carry over to next chunk for incomplete UTF-8 characters
|
beam_size=cfg.beam_size
|
||||||
self.pending_incomplete_tokens = []
|
)
|
||||||
|
|
||||||
def remove_hooks(self):
|
|
||||||
for hook in self.l_hooks:
|
|
||||||
hook.remove()
|
|
||||||
|
|
||||||
def warmup(self, audio):
|
def warmup(self, audio):
|
||||||
try:
|
try:
|
||||||
@@ -204,96 +200,100 @@ class AlignAtt:
|
|||||||
num_languages=self.model.num_languages,
|
num_languages=self.model.num_languages,
|
||||||
task=self.decode_options.task
|
task=self.decode_options.task
|
||||||
)
|
)
|
||||||
|
self.state.tokenizer = self.tokenizer
|
||||||
|
|
||||||
def init_context(self):
|
def init_context(self):
|
||||||
kw = {'tokenizer': self.tokenizer,
|
kw = {'tokenizer': self.tokenizer,
|
||||||
'device': self.model.device,
|
'device': self.model.device,
|
||||||
'prefix_token_ids': [self.tokenizer.sot_prev]}
|
'prefix_token_ids': [self.tokenizer.sot_prev]}
|
||||||
self.context = TokenBuffer.empty(**kw)
|
self.state.context = TokenBuffer.empty(**kw)
|
||||||
if self.cfg.static_init_prompt is not None:
|
if self.cfg.static_init_prompt is not None:
|
||||||
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||||
if self.cfg.init_prompt is not None:
|
if self.cfg.init_prompt is not None:
|
||||||
self.context.text += self.cfg.init_prompt
|
self.state.context.text += self.cfg.init_prompt
|
||||||
|
|
||||||
def init_tokens(self):
|
def init_tokens(self):
|
||||||
logger.debug(f"init tokens, {len(self.segments)}")
|
logger.debug(f"init tokens, {len(self.state.segments)}")
|
||||||
# init tokens (mandatory prompt)
|
# init tokens (mandatory prompt)
|
||||||
self.initial_tokens = torch.tensor(
|
self.state.initial_tokens = torch.tensor(
|
||||||
self.tokenizer.sot_sequence_including_notimestamps,
|
self.tokenizer.sot_sequence_including_notimestamps,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=self.model.device).unsqueeze(0)
|
device=self.model.device).unsqueeze(0)
|
||||||
self.initial_token_length = self.initial_tokens.shape[1]
|
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
||||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||||
# self.segments = []
|
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
||||||
logger.debug(f"init tokens after, {len(self.segments)}")
|
self.state.tokens = [self.state.initial_tokens]
|
||||||
self.tokens = [self.initial_tokens]
|
|
||||||
|
|
||||||
def trim_context(self):
|
def trim_context(self):
|
||||||
logger.info("Trimming context")
|
logger.info("Trimming context")
|
||||||
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
|
||||||
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
|
logger.info(f"Context text: {self.state.context.as_text()}")
|
||||||
logger.info(f"Context text: {self.context.as_text()}")
|
l = sum(t.shape[1] for t in self.state.tokens) + c
|
||||||
# logger.debug(f"Context tensor: {self.context.as_tensor()}")
|
|
||||||
l = sum(t.shape[1] for t in self.tokens) + c
|
|
||||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
|
||||||
if self.cfg.static_init_prompt is None:
|
if self.cfg.static_init_prompt is None:
|
||||||
after = 0
|
after = 0
|
||||||
else:
|
else:
|
||||||
after = len(self.cfg.static_init_prompt)
|
after = len(self.cfg.static_init_prompt)
|
||||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
|
||||||
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
||||||
t = self.context.trim_words(after=after)
|
t = self.state.context.trim_words(after=after)
|
||||||
l -= t
|
l -= t
|
||||||
c -= t
|
c -= t
|
||||||
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||||
if t == 0:
|
if t == 0:
|
||||||
break
|
break
|
||||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
|
||||||
logger.info(f"Context after trim: {self.context.text} (len: {l})")
|
|
||||||
|
|
||||||
|
|
||||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
|
def logits(
|
||||||
if self.cfg.decoder_type == "greedy":
|
self,
|
||||||
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
tokens: torch.Tensor,
|
||||||
|
audio_features: torch.Tensor,
|
||||||
|
return_cross_attn: bool = False
|
||||||
|
):
|
||||||
|
"""Get logits from decoder, optionally returning cross-attention weights."""
|
||||||
|
if self.state.decoder_type == "greedy":
|
||||||
|
return self.model.decoder(
|
||||||
|
tokens, audio_features,
|
||||||
|
kv_cache=self.state.kv_cache,
|
||||||
|
return_cross_attn=return_cross_attn
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Logits shape: {tokens.shape}")
|
logger.debug(f"Logits shape: {tokens.shape}")
|
||||||
logit = self.inference.logits(tokens, audio_features)
|
return self.state.inference.logits(
|
||||||
return logit
|
tokens, audio_features,
|
||||||
|
return_cross_attn=return_cross_attn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def refresh_segment(self, complete=False):
|
def refresh_segment(self, complete=False):
|
||||||
|
|
||||||
logger.debug("Refreshing segment:")
|
logger.debug("Refreshing segment:")
|
||||||
self.init_tokens()
|
self.init_tokens()
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
# self.detected_language = None
|
self.state.cumulative_time_offset = 0.0
|
||||||
self.cumulative_time_offset = 0.0
|
|
||||||
self.init_context()
|
self.init_context()
|
||||||
logger.debug(f"Context: {self.context}")
|
logger.debug(f"Context: {self.state.context}")
|
||||||
if not complete and len(self.segments) > 2:
|
if not complete and len(self.state.segments) > 2:
|
||||||
self.segments = self.segments[-2:]
|
self.state.segments = self.state.segments[-2:]
|
||||||
else:
|
else:
|
||||||
logger.debug("removing all segments.")
|
logger.debug("removing all segments.")
|
||||||
self.segments = []
|
self.state.segments = []
|
||||||
self.log_segments += 1
|
self.state.log_segments += 1
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
self.pending_incomplete_tokens = []
|
|
||||||
|
|
||||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||||
if self.always_fire: return True
|
if self.state.always_fire:
|
||||||
if self.never_fire: return False
|
return True
|
||||||
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
if self.state.never_fire:
|
||||||
|
return False
|
||||||
|
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
|
||||||
|
|
||||||
def _current_tokens(self):
|
def _current_tokens(self):
|
||||||
|
toks = self.state.tokens
|
||||||
toks = self.tokens
|
|
||||||
# very first infer: duplicate start of seq to beam_size
|
# very first infer: duplicate start of seq to beam_size
|
||||||
if toks[0].shape[0] == 1:
|
if toks[0].shape[0] == 1:
|
||||||
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
|
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
|
||||||
|
|
||||||
if not self.context.is_empty():
|
if not self.state.context.is_empty():
|
||||||
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
|
context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
|
||||||
toks = [context_toks] + toks
|
toks = [context_toks] + toks
|
||||||
|
|
||||||
# make it one tensor
|
# make it one tensor
|
||||||
@@ -313,7 +313,7 @@ class AlignAtt:
|
|||||||
### audio buffer
|
### audio buffer
|
||||||
|
|
||||||
def segments_len(self):
|
def segments_len(self):
|
||||||
segments_len = sum(s.shape[0] for s in self.segments) / 16000
|
segments_len = sum(s.shape[0] for s in self.state.segments) / 16000
|
||||||
return segments_len
|
return segments_len
|
||||||
|
|
||||||
def _apply_minseglen(self):
|
def _apply_minseglen(self):
|
||||||
@@ -326,42 +326,36 @@ class AlignAtt:
|
|||||||
|
|
||||||
def insert_audio(self, segment=None):
|
def insert_audio(self, segment=None):
|
||||||
if segment is not None:
|
if segment is not None:
|
||||||
self.segments.append(segment)
|
self.state.segments.append(segment)
|
||||||
|
|
||||||
removed_len = 0
|
removed_len = 0
|
||||||
# len of audio is bigger than buffer_len. Going to remove the first segment
|
# len of audio is bigger than buffer_len. Going to remove the first segment
|
||||||
segments_len = self.segments_len()
|
segments_len = self.segments_len()
|
||||||
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||||
removed_len = self.segments[0].shape[0] / 16000
|
removed_len = self.state.segments[0].shape[0] / 16000
|
||||||
segments_len -= removed_len
|
segments_len -= removed_len
|
||||||
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
||||||
self.cumulative_time_offset += removed_len # Track cumulative time removed
|
self.state.cumulative_time_offset += removed_len # Track cumulative time removed
|
||||||
self.segments = self.segments[1:]
|
self.state.segments = self.state.segments[1:]
|
||||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
|
||||||
if len(self.tokens) > 1:
|
if len(self.state.tokens) > 1:
|
||||||
self.context.append_token_ids(self.tokens[1][0,:].tolist())
|
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
|
||||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
||||||
return removed_len
|
return removed_len
|
||||||
|
|
||||||
def _clean_cache(self):
|
def _clean_cache(self):
|
||||||
'''clean the cache that stores the attention matrices and kv_cache.
|
"""Clean the kv_cache after each inference step."""
|
||||||
It must be called every time after generation with the model.'''
|
self.state.clean_cache()
|
||||||
# cleaning cache
|
|
||||||
self.dec_attns = []
|
|
||||||
self.kv_cache = {}
|
|
||||||
if self.decoder_type == "beam":
|
|
||||||
self.inference.kv_cache = self.kv_cache
|
|
||||||
self.token_decoder.reset()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def lang_id(self, encoder_features):
|
def lang_id(self, encoder_features):
|
||||||
"""Language detection from encoder features.
|
"""Language detection from encoder features.
|
||||||
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
|
This code is trimmed and copy-pasted from whisper.decoding.detect_language.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# forward pass using a single token, startoftranscript
|
# forward pass using a single token, startoftranscript
|
||||||
n_audio = encoder_features.shape[0]
|
n_audio = encoder_features.shape[0]
|
||||||
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
||||||
|
# Note: don't use kv_cache for language detection
|
||||||
logits = self.model.logits(x, encoder_features)[:, 0]
|
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||||
|
|
||||||
# collect detected languages; suppress all non-language tokens
|
# collect detected languages; suppress all non-language tokens
|
||||||
@@ -391,19 +385,19 @@ class AlignAtt:
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def infer(self, is_last=False):
|
def infer(self, is_last=False):
|
||||||
new_segment = True
|
new_segment = True
|
||||||
if len(self.segments) == 0:
|
if len(self.state.segments) == 0:
|
||||||
logger.debug("No segments, nothing to do")
|
logger.debug("No segments, nothing to do")
|
||||||
return []
|
return []
|
||||||
if not self._apply_minseglen():
|
if not self._apply_minseglen():
|
||||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||||
input_segments = torch.cat(self.segments, dim=0)
|
input_segments = torch.cat(self.state.segments, dim=0)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# input_segments is concatenation of audio, it's one array
|
# input_segments is concatenation of audio, it's one array
|
||||||
if len(self.segments) > 1:
|
if len(self.state.segments) > 1:
|
||||||
input_segments = torch.cat(self.segments, dim=0)
|
input_segments = torch.cat(self.state.segments, dim=0)
|
||||||
else:
|
else:
|
||||||
input_segments = self.segments[0]
|
input_segments = self.state.segments[0]
|
||||||
|
|
||||||
beg_encode = time()
|
beg_encode = time()
|
||||||
if self.use_mlcore:
|
if self.use_mlcore:
|
||||||
@@ -457,18 +451,18 @@ class AlignAtt:
|
|||||||
end_encode = time()
|
end_encode = time()
|
||||||
# print('Encoder duration:', end_encode-beg_encode)
|
# print('Encoder duration:', end_encode-beg_encode)
|
||||||
|
|
||||||
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
|
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
|
||||||
seconds_since_start = self.segments_len() - self.first_timestamp
|
seconds_since_start = self.segments_len() - self.state.first_timestamp
|
||||||
if seconds_since_start >= 2.0:
|
if seconds_since_start >= 2.0:
|
||||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
self.create_tokenizer(top_lan)
|
self.create_tokenizer(top_lan)
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
self.cumulative_time_offset = 0.0
|
self.state.cumulative_time_offset = 0.0
|
||||||
self.init_tokens()
|
self.init_tokens()
|
||||||
self.init_context()
|
self.init_context()
|
||||||
self.detected_language = top_lan
|
self.state.detected_language = top_lan
|
||||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||||
|
|
||||||
self.trim_context()
|
self.trim_context()
|
||||||
@@ -488,92 +482,80 @@ class AlignAtt:
|
|||||||
|
|
||||||
l_absolute_timestamps = []
|
l_absolute_timestamps = []
|
||||||
|
|
||||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
accumulated_cross_attns = []
|
||||||
|
|
||||||
|
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||||
|
|
||||||
if new_segment:
|
if new_segment:
|
||||||
tokens_for_logits = current_tokens
|
tokens_for_logits = current_tokens
|
||||||
else:
|
else:
|
||||||
# only need to use the last token except in the first forward pass
|
# only need to use the last token except in the first forward pass
|
||||||
tokens_for_logits = current_tokens[:,-1:]
|
tokens_for_logits = current_tokens[:, -1:]
|
||||||
|
|
||||||
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
# Get logits and cross-attention weights from decoder
|
||||||
|
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
||||||
|
logits, cross_attns = result
|
||||||
|
|
||||||
|
# Accumulate cross-attention from this forward pass
|
||||||
|
accumulated_cross_attns.append(cross_attns)
|
||||||
|
|
||||||
if new_segment and self.tokenizer.no_speech is not None:
|
if new_segment and self.tokenizer.no_speech is not None:
|
||||||
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||||
logger.info("no speech, stop")
|
logger.info("no speech, stop")
|
||||||
break
|
break
|
||||||
|
|
||||||
logits = logits[:, -1, :] # logits for the last token
|
logits = logits[:, -1, :] # logits for the last token
|
||||||
|
|
||||||
# supress blank tokens only at the beginning of the segment
|
# suppress blank tokens only at the beginning of the segment
|
||||||
if new_segment:
|
if new_segment:
|
||||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||||
new_segment = False
|
new_segment = False
|
||||||
self.suppress_tokens(logits)
|
self.state.suppress_tokens_fn(logits)
|
||||||
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||||
|
|
||||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||||
self.debug_print_tokens(current_tokens)
|
self.debug_print_tokens(current_tokens)
|
||||||
|
|
||||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
# Process accumulated cross-attention weights for alignment
|
||||||
for i, attn_mat in enumerate(self.dec_attns):
|
attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
|
||||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
|
||||||
align_heads_in_layer = self.align_source.get(layer_rank, [])
|
|
||||||
if len(align_heads_in_layer) == 0:
|
|
||||||
continue
|
|
||||||
for align_head_rank, head_id in align_heads_in_layer:
|
|
||||||
if self.cfg.beam_size == 1:
|
|
||||||
a = attn_mat[head_id, :, :]
|
|
||||||
a = a.unsqueeze(0)
|
|
||||||
else:
|
|
||||||
a = attn_mat[:, head_id, :, :]
|
|
||||||
attn_of_alignment_heads[align_head_rank].append(a)
|
|
||||||
tmp = []
|
|
||||||
for mat in attn_of_alignment_heads:
|
|
||||||
t = torch.cat(mat, dim=1)
|
|
||||||
tmp.append(t)
|
|
||||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
|
||||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
|
||||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
|
||||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
|
||||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
|
||||||
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
|
||||||
|
|
||||||
# for each beam, the most attended frame is:
|
# for each beam, the most attended frame is:
|
||||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1)
|
||||||
|
|
||||||
# Calculate absolute timestamps accounting for cumulative offset
|
# Calculate absolute timestamps accounting for cumulative offset
|
||||||
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
absolute_timestamps = [
|
||||||
|
(frame * 0.02 + self.state.cumulative_time_offset)
|
||||||
|
for frame in most_attended_frames.tolist()
|
||||||
|
]
|
||||||
|
|
||||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)")
|
||||||
|
|
||||||
most_attended_frame = most_attended_frames[0].item()
|
most_attended_frame = most_attended_frames[0].item()
|
||||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||||
|
|
||||||
logger.debug("current tokens" + str(current_tokens.shape))
|
logger.debug("current tokens" + str(current_tokens.shape))
|
||||||
if completed:
|
if completed:
|
||||||
# # stripping the last token, the eot
|
# stripping the last token, the eot
|
||||||
current_tokens = current_tokens[:, :-1]
|
current_tokens = current_tokens[:, :-1]
|
||||||
break
|
break
|
||||||
|
|
||||||
# for some rare cases where the attention fails
|
# for some rare cases where the attention fails
|
||||||
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
||||||
# TODO: check this
|
|
||||||
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
|
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
|
||||||
logger.debug("ommit rewinding from special tokens")
|
logger.debug("omit rewinding from special tokens")
|
||||||
self.last_attend_frame = most_attended_frame
|
self.state.last_attend_frame = most_attended_frame
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[rewind detected] current attention pos: {most_attended_frame}, "
|
f"[rewind detected] current attention pos: {most_attended_frame}, "
|
||||||
f"last attention pos: {self.last_attend_frame}; omit this segment")
|
f"last attention pos: {self.state.last_attend_frame}; omit this segment")
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
|
current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.last_attend_frame = most_attended_frame
|
self.state.last_attend_frame = most_attended_frame
|
||||||
|
|
||||||
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
||||||
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
||||||
@@ -593,12 +575,12 @@ class AlignAtt:
|
|||||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||||
|
|
||||||
# Prepend pending tokens from previous chunk if any
|
# Prepend pending tokens from previous chunk if any
|
||||||
if self.pending_incomplete_tokens:
|
if self.state.pending_incomplete_tokens:
|
||||||
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
|
logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}")
|
||||||
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||||
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
||||||
|
|
||||||
if fire_detected or is_last: #or punctuation_stop:
|
if fire_detected or is_last:
|
||||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||||
else:
|
else:
|
||||||
@@ -609,20 +591,18 @@ class AlignAtt:
|
|||||||
else:
|
else:
|
||||||
new_hypothesis = []
|
new_hypothesis = []
|
||||||
|
|
||||||
|
|
||||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.tokens.append(new_tokens)
|
self.state.tokens.append(new_tokens)
|
||||||
|
|
||||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||||
|
|
||||||
self._clean_cache()
|
self._clean_cache()
|
||||||
|
|
||||||
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
|
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
|
||||||
self.first_timestamp = l_absolute_timestamps[0]
|
self.state.first_timestamp = l_absolute_timestamps[0]
|
||||||
|
|
||||||
|
|
||||||
timestamped_words = []
|
timestamped_words = []
|
||||||
timestamp_idx = 0
|
timestamp_idx = 0
|
||||||
@@ -641,20 +621,85 @@ class AlignAtt:
|
|||||||
timestamp_idx += len(word_tokens)
|
timestamp_idx += len(word_tokens)
|
||||||
|
|
||||||
timestamp_entry = ASRToken(
|
timestamp_entry = ASRToken(
|
||||||
start=round(current_timestamp, 2),
|
start=round(current_timestamp, 2),
|
||||||
end=round(current_timestamp + 0.1, 2),
|
end=round(current_timestamp + 0.1, 2),
|
||||||
text= word,
|
text=word,
|
||||||
speaker=self.speaker,
|
speaker=self.state.speaker,
|
||||||
detected_language=self.detected_language
|
detected_language=self.state.detected_language
|
||||||
).with_offset(
|
).with_offset(
|
||||||
self.global_time_offset
|
self.state.global_time_offset
|
||||||
)
|
)
|
||||||
timestamped_words.append(timestamp_entry)
|
timestamped_words.append(timestamp_entry)
|
||||||
|
|
||||||
# Hold incomplete tokens for next chunk
|
# Hold incomplete tokens for next chunk
|
||||||
self.pending_incomplete_tokens = []
|
self.state.pending_incomplete_tokens = []
|
||||||
if split_words and replacement_char in split_words[-1]:
|
if split_words and replacement_char in split_words[-1]:
|
||||||
self.pending_incomplete_tokens = split_tokens[-1]
|
self.state.pending_incomplete_tokens = split_tokens[-1]
|
||||||
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
|
logger.warning(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.state.pending_incomplete_tokens}")
|
||||||
|
|
||||||
return timestamped_words
|
return timestamped_words
|
||||||
|
|
||||||
|
def _process_cross_attention(
|
||||||
|
self,
|
||||||
|
cross_attns: List[torch.Tensor],
|
||||||
|
content_mel_len: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Process cross-attention weights from decoder layers for alignment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cross_attns: List of cross-attention tensors from each decoder layer.
|
||||||
|
Each tensor has shape (batch, n_head, seq_len, audio_len)
|
||||||
|
content_mel_len: Length of actual audio content in mel frames
|
||||||
|
|
||||||
|
Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len)
|
||||||
|
"""
|
||||||
|
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
||||||
|
num_decoder_layers = len(self.model.decoder.blocks)
|
||||||
|
|
||||||
|
if cross_attns and isinstance(cross_attns[0], list):
|
||||||
|
flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list]
|
||||||
|
else:
|
||||||
|
flattened_attns = cross_attns
|
||||||
|
|
||||||
|
for idx, attn_mat in enumerate(flattened_attns):
|
||||||
|
layer_rank = idx % num_decoder_layers
|
||||||
|
# attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1
|
||||||
|
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
||||||
|
if len(align_heads_in_layer) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
attn_mat = F.softmax(attn_mat, dim=-1)
|
||||||
|
|
||||||
|
for align_head_rank, head_id in align_heads_in_layer:
|
||||||
|
if self.cfg.beam_size == 1:
|
||||||
|
# (n_head, seq_len, audio_len) when squeezed
|
||||||
|
if attn_mat.dim() == 4:
|
||||||
|
a = attn_mat[0, head_id, :, :] # (seq_len, audio_len)
|
||||||
|
else:
|
||||||
|
a = attn_mat[head_id, :, :]
|
||||||
|
a = a.unsqueeze(0) # (1, seq_len, audio_len)
|
||||||
|
else:
|
||||||
|
# attn_mat: (batch, n_head, seq_len, audio_len)
|
||||||
|
a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len)
|
||||||
|
attn_of_alignment_heads[align_head_rank].append(a)
|
||||||
|
|
||||||
|
tmp = []
|
||||||
|
for mat in attn_of_alignment_heads:
|
||||||
|
if mat:
|
||||||
|
t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len)
|
||||||
|
tmp.append(t)
|
||||||
|
|
||||||
|
if not tmp:
|
||||||
|
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
|
||||||
|
|
||||||
|
# stck al heads: (batch, num_align_heads, seq_len, audio_len)
|
||||||
|
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||||
|
|
||||||
|
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||||
|
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
||||||
|
|
||||||
|
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
|
||||||
|
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||||
|
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
||||||
|
return attn_of_alignment_heads
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
import torch
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TokenBuffer:
|
class TokenBuffer:
|
||||||
|
|
||||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, List, Union, Dict, Any
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import Optional, List, Tuple, Union, Any
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
|
from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence,
|
||||||
|
SilentLine, SpeakerSegment,
|
||||||
|
TimedText)
|
||||||
|
|
||||||
|
|
||||||
class TokensAlignment:
|
class TokensAlignment:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ def load_file(warmup_file=None, timeout=5):
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
if warmup_file == "":
|
if warmup_file == "":
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
|
||||||
import importlib.resources as resources
|
|
||||||
import base64
|
import base64
|
||||||
|
import importlib.resources as resources
|
||||||
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -96,11 +96,13 @@ def get_inline_ui_html():
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
import uvicorn
|
|
||||||
from starlette.staticfiles import StaticFiles
|
from starlette.staticfiles import StaticFiles
|
||||||
import pathlib
|
|
||||||
import whisperlivekit.web as webpkg
|
import whisperlivekit.web as webpkg
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|||||||
@@ -4,15 +4,17 @@ import json
|
|||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
|
||||||
from pathlib import Path
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
|
||||||
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
|
pad_or_trim)
|
||||||
|
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
|
||||||
|
decode, detect_language)
|
||||||
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||||
from whisperlivekit.whisper.transcribe import transcribe
|
from whisperlivekit.whisper.transcribe import transcribe
|
||||||
from whisperlivekit.whisper.version import __version__
|
from whisperlivekit.whisper.version import __version__
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
|
||||||
|
Tuple, Union)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -146,16 +147,13 @@ class PyTorchInference(Inference):
|
|||||||
self.model: "Whisper" = model
|
self.model: "Whisper" = model
|
||||||
self.initial_token_length = initial_token_length
|
self.initial_token_length = initial_token_length
|
||||||
self.kv_cache = {}
|
self.kv_cache = {}
|
||||||
self.hooks = []
|
|
||||||
|
|
||||||
key_modules = [block.attn.key for block in self.model.decoder.blocks]
|
self.kv_cache_ids = []
|
||||||
value_modules = [block.attn.value for block in self.model.decoder.blocks]
|
for block in self.model.decoder.blocks:
|
||||||
self.kv_modules = key_modules + value_modules
|
self.kv_cache_ids.append(block.attn.key_cache_id)
|
||||||
|
self.kv_cache_ids.append(block.attn.value_cache_id)
|
||||||
|
|
||||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||||
if not self.kv_cache:
|
|
||||||
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
|
||||||
|
|
||||||
if tokens.shape[-1] > self.initial_token_length:
|
if tokens.shape[-1] > self.initial_token_length:
|
||||||
# only need to use the last token except in the first forward pass
|
# only need to use the last token except in the first forward pass
|
||||||
tokens = tokens[:, -1:]
|
tokens = tokens[:, -1:]
|
||||||
@@ -163,17 +161,14 @@ class PyTorchInference(Inference):
|
|||||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||||
|
|
||||||
def cleanup_caching(self):
|
def cleanup_caching(self):
|
||||||
for hook in self.hooks:
|
|
||||||
hook.remove()
|
|
||||||
|
|
||||||
self.kv_cache = {}
|
self.kv_cache = {}
|
||||||
self.hooks = []
|
|
||||||
|
|
||||||
def rearrange_kv_cache(self, source_indices):
|
def rearrange_kv_cache(self, source_indices):
|
||||||
if source_indices != list(range(len(source_indices))):
|
if source_indices != list(range(len(source_indices))):
|
||||||
for module in self.kv_modules:
|
for cache_id in self.kv_cache_ids:
|
||||||
# update the key/value cache to contain the selected sequences
|
if cache_id in self.kv_cache:
|
||||||
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
|
# update the key/value cache to contain the selected sequences
|
||||||
|
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
|
||||||
|
|
||||||
|
|
||||||
class SequenceRanker:
|
class SequenceRanker:
|
||||||
|
|||||||
@@ -79,18 +79,23 @@ def disable_sdpa():
|
|||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
|
use_sdpa = False # Disable SDPA to ensure qk is always computed when needed
|
||||||
|
|
||||||
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
|
def __init__(self, n_state: int, n_head: int, cache_id: str = "", n_text_ctx: int = 448):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
|
self.n_text_ctx = n_text_ctx
|
||||||
self.query = Linear(n_state, n_state)
|
self.query = Linear(n_state, n_state)
|
||||||
self.key = Linear(n_state, n_state, bias=False)
|
self.key = Linear(n_state, n_state, bias=False)
|
||||||
self.value = Linear(n_state, n_state)
|
self.value = Linear(n_state, n_state)
|
||||||
self.out = Linear(n_state, n_state)
|
self.out = Linear(n_state, n_state)
|
||||||
self.cache_id = cache_id
|
self.cache_id = cache_id
|
||||||
self.key.cache_id = f"{cache_id}_key"
|
# Cache IDs for key and value (used with dict-based kv_cache)
|
||||||
self.value.cache_id = f"{cache_id}_value"
|
self.key_cache_id = f"{cache_id}_key"
|
||||||
|
self.value_cache_id = f"{cache_id}_value"
|
||||||
|
# Keep these for backward compatibility with hook-based caching
|
||||||
|
self.key.cache_id = self.key_cache_id
|
||||||
|
self.value.cache_id = self.value_cache_id
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -101,19 +106,45 @@ class MultiHeadAttention(nn.Module):
|
|||||||
):
|
):
|
||||||
q = self.query(x)
|
q = self.query(x)
|
||||||
|
|
||||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
if xa is None:
|
||||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
# Self-attention
|
||||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
k = self.key(x)
|
||||||
k = self.key(x if xa is None else xa)
|
v = self.value(x)
|
||||||
v = self.value(x if xa is None else xa)
|
if kv_cache is not None:
|
||||||
|
k, v = self._update_self_attn_cache(k, v, kv_cache)
|
||||||
else:
|
else:
|
||||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
# Cross-attention: compute once and cache, or reuse from cache
|
||||||
k = kv_cache[self.key]
|
if kv_cache is not None and self.key_cache_id in kv_cache:
|
||||||
v = kv_cache[self.value]
|
k = kv_cache[self.key_cache_id]
|
||||||
|
v = kv_cache[self.value_cache_id]
|
||||||
|
else:
|
||||||
|
k = self.key(xa)
|
||||||
|
v = self.value(xa)
|
||||||
|
if kv_cache is not None:
|
||||||
|
kv_cache[self.key_cache_id] = k
|
||||||
|
kv_cache[self.value_cache_id] = v
|
||||||
|
|
||||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||||
return self.out(wv), qk
|
return self.out(wv), qk
|
||||||
|
|
||||||
|
def _update_self_attn_cache(
|
||||||
|
self, k: Tensor, v: Tensor, kv_cache: dict
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""Update self-attention kv cache by concatenating new k,v with cached values."""
|
||||||
|
if self.key_cache_id not in kv_cache or k.shape[1] > self.n_text_ctx:
|
||||||
|
# First token or context overflow: save as-is
|
||||||
|
kv_cache[self.key_cache_id] = k.detach()
|
||||||
|
kv_cache[self.value_cache_id] = v.detach()
|
||||||
|
else:
|
||||||
|
# Concatenate with existing cache
|
||||||
|
cached_k = kv_cache[self.key_cache_id]
|
||||||
|
cached_v = kv_cache[self.value_cache_id]
|
||||||
|
k = torch.cat([cached_k, k], dim=1).detach()
|
||||||
|
v = torch.cat([cached_v, v], dim=1).detach()
|
||||||
|
kv_cache[self.key_cache_id] = k
|
||||||
|
kv_cache[self.value_cache_id] = v
|
||||||
|
return k, v
|
||||||
|
|
||||||
def qkv_attention(
|
def qkv_attention(
|
||||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@@ -143,14 +174,21 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
|
def __init__(
|
||||||
|
self, n_state: int, n_head: int, cross_attention: bool = False,
|
||||||
|
cache_id: str = "", n_text_ctx: int = 448
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
self.attn = MultiHeadAttention(
|
||||||
|
n_state, n_head, cache_id=f"{cache_id}_self_attn", n_text_ctx=n_text_ctx
|
||||||
|
)
|
||||||
self.attn_ln = LayerNorm(n_state)
|
self.attn_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
self.cross_attn = (
|
self.cross_attn = (
|
||||||
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
MultiHeadAttention(
|
||||||
|
n_state, n_head, cache_id=f"{cache_id}_cross_attn", n_text_ctx=n_text_ctx
|
||||||
|
) if cross_attention else None
|
||||||
)
|
)
|
||||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||||
|
|
||||||
@@ -166,12 +204,21 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
xa: Optional[Tensor] = None,
|
xa: Optional[Tensor] = None,
|
||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[dict] = None,
|
kv_cache: Optional[dict] = None,
|
||||||
):
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
x: The output tensor
|
||||||
|
cross_attn_qk: Cross-attention weights (if cross_attn exists), else None
|
||||||
|
"""
|
||||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||||
|
cross_attn_qk = None
|
||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
cross_out, cross_attn_qk = self.cross_attn(
|
||||||
|
self.cross_attn_ln(x), xa, kv_cache=kv_cache
|
||||||
|
)
|
||||||
|
x = x + cross_out
|
||||||
x = x + self.mlp(self.mlp_ln(x))
|
x = x + self.mlp(self.mlp_ln(x))
|
||||||
return x
|
return x, cross_attn_qk
|
||||||
|
|
||||||
|
|
||||||
class AudioEncoder(nn.Module):
|
class AudioEncoder(nn.Module):
|
||||||
@@ -201,7 +248,7 @@ class AudioEncoder(nn.Module):
|
|||||||
x = (x + self.positional_embedding).to(x.dtype)
|
x = (x + self.positional_embedding).to(x.dtype)
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x, _ = block(x) # Encoder blocks don't have cross-attention
|
||||||
|
|
||||||
x = self.ln_post(x)
|
x = self.ln_post(x)
|
||||||
return x
|
return x
|
||||||
@@ -212,13 +259,17 @@ class TextDecoder(nn.Module):
|
|||||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.n_ctx = n_ctx
|
||||||
|
|
||||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||||
|
|
||||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||||
[
|
[
|
||||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
|
ResidualAttentionBlock(
|
||||||
|
n_state, n_head, cross_attention=True,
|
||||||
|
cache_id=f"dec_layer{i}", n_text_ctx=n_ctx
|
||||||
|
)
|
||||||
for i in range(n_layer)
|
for i in range(n_layer)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -227,28 +278,57 @@ class TextDecoder(nn.Module):
|
|||||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||||
self.register_buffer("mask", mask, persistent=False)
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
|
|
||||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
xa: Tensor,
|
||||||
|
kv_cache: Optional[dict] = None,
|
||||||
|
return_cross_attn: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||||
the text tokens
|
the text tokens
|
||||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||||
the encoded audio features to be attended on
|
the encoded audio features to be attended on
|
||||||
|
kv_cache : Optional[dict]
|
||||||
|
Dictionary to store/retrieve key-value cache for efficient decoding
|
||||||
|
return_cross_attn : bool
|
||||||
|
If True, return cross-attention weights from all decoder layers
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
logits : Tensor
|
||||||
|
The output logits
|
||||||
|
cross_attns : Optional[List[Tensor]]
|
||||||
|
List of cross-attention weights per layer (only if return_cross_attn=True)
|
||||||
"""
|
"""
|
||||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
# Calculate offset from self-attention cache (not cross-attention which has audio length)
|
||||||
|
offset = 0
|
||||||
|
if kv_cache:
|
||||||
|
# Use the first decoder block's self-attention key cache to get token position
|
||||||
|
first_self_attn_key = self.blocks[0].attn.key_cache_id
|
||||||
|
if first_self_attn_key in kv_cache:
|
||||||
|
offset = kv_cache[first_self_attn_key].shape[1]
|
||||||
|
|
||||||
x = (
|
x = (
|
||||||
self.token_embedding(x)
|
self.token_embedding(x)
|
||||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||||
)
|
)
|
||||||
x = x.to(xa.dtype)
|
x = x.to(xa.dtype)
|
||||||
|
|
||||||
|
cross_attns = [] if return_cross_attn else None
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
x, cross_attn_qk = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||||
|
if return_cross_attn and cross_attn_qk is not None:
|
||||||
|
cross_attns.append(cross_attn_qk)
|
||||||
|
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
logits = (
|
logits = (
|
||||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||||
).float()
|
).float()
|
||||||
|
|
||||||
|
if return_cross_attn:
|
||||||
|
return logits, cross_attns
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@@ -292,8 +372,18 @@ class Whisper(nn.Module):
|
|||||||
def embed_audio(self, mel: torch.Tensor):
|
def embed_audio(self, mel: torch.Tensor):
|
||||||
return self.encoder(mel)
|
return self.encoder(mel)
|
||||||
|
|
||||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
def logits(
|
||||||
return self.decoder(tokens, audio_features)
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
audio_features: torch.Tensor,
|
||||||
|
kv_cache: Optional[dict] = None,
|
||||||
|
return_cross_attn: bool = False,
|
||||||
|
):
|
||||||
|
return self.decoder(
|
||||||
|
tokens, audio_features,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
return_cross_attn=return_cross_attn
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||||
@@ -312,39 +402,6 @@ class Whisper(nn.Module):
|
|||||||
def num_languages(self):
|
def num_languages(self):
|
||||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||||
|
|
||||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
|
||||||
"""
|
|
||||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
|
||||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
|
||||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
|
||||||
intermediate tensors to be reused during later calculations.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
cache : Dict[nn.Module, torch.Tensor]
|
|
||||||
A dictionary object mapping the key/value projection modules to its cache
|
|
||||||
hooks : List[RemovableHandle]
|
|
||||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
|
||||||
"""
|
|
||||||
cache = {**cache} if cache is not None else {}
|
|
||||||
hooks = []
|
|
||||||
|
|
||||||
def save_to_cache(module, _, output):
|
|
||||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
|
||||||
# save as-is, for the first token or cross attention
|
|
||||||
cache[module] = output
|
|
||||||
else:
|
|
||||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
|
||||||
return cache[module]
|
|
||||||
|
|
||||||
def install_hooks(layer: nn.Module):
|
|
||||||
if isinstance(layer, MultiHeadAttention):
|
|
||||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
|
||||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
|
||||||
|
|
||||||
self.decoder.apply(install_hooks)
|
|
||||||
return cache, hooks
|
|
||||||
|
|
||||||
detect_language = detect_language_function
|
detect_language = detect_language_function
|
||||||
transcribe = transcribe_function
|
transcribe = transcribe_function
|
||||||
decode = decode_function
|
decode = decode_function
|
||||||
|
|||||||
@@ -8,28 +8,13 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from .audio import (
|
from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES,
|
||||||
FRAMES_PER_SECOND,
|
SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
|
||||||
HOP_LENGTH,
|
|
||||||
N_FRAMES,
|
|
||||||
N_SAMPLES,
|
|
||||||
SAMPLE_RATE,
|
|
||||||
log_mel_spectrogram,
|
|
||||||
pad_or_trim,
|
|
||||||
)
|
|
||||||
from .decoding import DecodingOptions, DecodingResult
|
from .decoding import DecodingOptions, DecodingResult
|
||||||
from .timing import add_word_timestamps
|
from .timing import add_word_timestamps
|
||||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||||
from .utils import (
|
from .utils import (exact_div, format_timestamp, get_end, get_writer,
|
||||||
exact_div,
|
make_safe, optional_float, optional_int, str2bool)
|
||||||
format_timestamp,
|
|
||||||
get_end,
|
|
||||||
get_writer,
|
|
||||||
make_safe,
|
|
||||||
optional_float,
|
|
||||||
optional_int,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .model import Whisper
|
from .model import Whisper
|
||||||
|
|||||||
Reference in New Issue
Block a user