mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60c62f8f84 | ||
|
|
7faa21f95f | ||
|
|
4e9f951551 | ||
|
|
870141298c | ||
|
|
872faa422a | ||
|
|
fc9cb66813 | ||
|
|
a175d1a327 | ||
|
|
6206fff118 | ||
|
|
b5067249c0 | ||
|
|
f4f9831d39 | ||
|
|
254faaf64c | ||
|
|
8e7aea4fcf | ||
|
|
270faf2069 | ||
|
|
b7c1cc77cc | ||
|
|
9a45ec221c | ||
|
|
3e13ee6fc3 | ||
|
|
b7d20a0ff0 | ||
|
|
c1bb9c2bde | ||
|
|
11e9def0b2 | ||
|
|
3104f40f6e | ||
|
|
e9b4ceeee5 | ||
|
|
437641fb43 | ||
|
|
bfd60b3921 | ||
|
|
1e67bf97f0 | ||
|
|
bbd4fd6cff | ||
|
|
28985962a0 | ||
|
|
a38c103fcd | ||
|
|
4d2ffb24f8 | ||
|
|
1bbbb7903c |
21
README.md
21
README.md
@@ -51,9 +51,11 @@ pip install whisperlivekit
|
||||
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||
|
||||
|
||||
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||
> - See [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||
> - Check the [troubleshooting guide](docs/troubleshooting.md) for step-by-step fixes collected from recent GPU setup/env issues.
|
||||
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||
|
||||
|
||||
#### Use it to capture audio from web pages.
|
||||
|
||||
@@ -96,11 +98,13 @@ wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
|
||||
|
||||
```python
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
from whisperlivekit import AudioProcessor, TranscriptionEngine, parse_args
|
||||
|
||||
transcription_engine = None
|
||||
|
||||
@@ -141,13 +145,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--no-vac` | Disable Voice Activity Controller. NOT ADVISED | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection. NOT ADVISED | `False` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
| `--host` | Server host address | `localhost` |
|
||||
| `--port` | Server port | `8000` |
|
||||
@@ -183,7 +187,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||
|
||||
|
||||
|
||||
|
||||
258
ReadmeJP.md
258
ReadmeJP.md
@@ -1,258 +0,0 @@
|
||||
<h1 align="center">WhisperLiveKit</h1>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
||||
</p>
|
||||
|
||||
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||
</p>
|
||||
|
||||
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
|
||||
|
||||
#### 主要な研究による技術:
|
||||
|
||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
|
||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
|
||||
|
||||
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか?** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
|
||||
|
||||
### アーキテクチャ
|
||||
|
||||
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
|
||||
|
||||
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
|
||||
|
||||
### インストールとクイックスタート
|
||||
|
||||
```bash
|
||||
pip install whisperlivekit
|
||||
```
|
||||
|
||||
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
|
||||
>
|
||||
> | OS | インストール方法 |
|
||||
> |-----------|-------------|
|
||||
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
||||
> | MacOS | `brew install ffmpeg` |
|
||||
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
|
||||
|
||||
#### クイックスタート
|
||||
1. **文字起こしサーバーを起動します:**
|
||||
```bash
|
||||
whisperlivekit-server --model base --language en
|
||||
```
|
||||
|
||||
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
|
||||
|
||||
|
||||
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
|
||||
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
|
||||
|
||||
#### オプションの依存関係
|
||||
|
||||
| オプション | `pip install` |
|
||||
|-----------|-------------|
|
||||
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| Diartによる話者ダイアライゼーション | `diart` |
|
||||
| オリジナルのWhisperバックエンド | `whisper` |
|
||||
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
|
||||
| Apple Silicon最適化バックエンド | `mlx-whisper` |
|
||||
| OpenAI APIバックエンド | `openai` |
|
||||
|
||||
それらの使用方法については、以下の**パラメータと設定**を参照してください。
|
||||
|
||||
### 使用例
|
||||
|
||||
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
|
||||
|
||||
```bash
|
||||
# デフォルト(small)より良いモデルを使用
|
||||
whisperlivekit-server --model large-v3
|
||||
|
||||
# ダイアライゼーションと言語を指定した高度な設定
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
|
||||
|
||||
```python
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global transcription_engine
|
||||
|
||||
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
|
||||
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
await websocket.accept()
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
```
|
||||
|
||||
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
|
||||
|
||||
|
||||
## パラメータと設定
|
||||
|
||||
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
|
||||
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
||||
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
|
||||
- `--backend`? `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
|
||||
- `--warmup-file`、もしあれば
|
||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
|
||||
- `--diarization`、使用したい場合。
|
||||
|
||||
残りは推奨しません。しかし、以下があなたのオプションです。
|
||||
|
||||
| パラメータ | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisperモデルのサイズ。 | `small` |
|
||||
| `--language` | ソース言語コードまたは`auto` | `auto` |
|
||||
| `--task` | `transcribe`または`translate` | `transcribe` |
|
||||
| `--backend` | 処理バックエンド | `simulstreaming` |
|
||||
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
|
||||
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
|
||||
| `--no-vad` | 音声区間検出を無効化 | `False` |
|
||||
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
|
||||
| `--host` | サーバーホストアドレス | `localhost` |
|
||||
| `--port` | サーバーポート | `8000` |
|
||||
| `--ssl-certfile` | SSL証明書ファイルへのパス(HTTPSサポート用) | `None` |
|
||||
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパス(HTTPSサポート用) | `None` |
|
||||
|
||||
|
||||
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
|
||||
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment`) | `segment` |
|
||||
|
||||
|
||||
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--frame-threshold` | AlignAttフレームしきい値(低いほど速く、高いほど正確) | `25` |
|
||||
| `--beams` | ビームサーチのビーム数(1 = 貪欲デコーディング) | `1` |
|
||||
| `--decoder` | デコーダタイプを強制(`beam`または`greedy`) | `auto` |
|
||||
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
|
||||
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
|
||||
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
|
||||
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
|
||||
| `--init-prompt` | モデルの初期プロンプト | `None` |
|
||||
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
|
||||
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
|
||||
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
|
||||
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
|
||||
|
||||
| ダイアライゼーションオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization` | 話者識別を有効化 | `False` |
|
||||
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
|
||||
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
|
||||
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です:
|
||||
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
|
||||
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
|
||||
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
|
||||
>4. HuggingFaceでログイン: `huggingface-cli login`
|
||||
|
||||
### 🚀 デプロイガイド
|
||||
|
||||
WhisperLiveKitを本番環境にデプロイするには:
|
||||
|
||||
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
|
||||
```bash
|
||||
pip install uvicorn gunicorn
|
||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
||||
```
|
||||
|
||||
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
|
||||
|
||||
3. **Nginx設定** (本番環境で推奨):
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name your-domain.com;
|
||||
location / {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
}}
|
||||
```
|
||||
|
||||
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
|
||||
|
||||
## 🐋 Docker
|
||||
|
||||
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
|
||||
|
||||
### 前提条件
|
||||
- Dockerがシステムにインストールされていること
|
||||
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
|
||||
|
||||
### クイックスタート
|
||||
|
||||
**GPUアクセラレーション付き (推奨):**
|
||||
```bash
|
||||
docker build -t wlk .
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
**CPUのみ:**
|
||||
```bash
|
||||
docker build -f Dockerfile.cpu -t wlk .
|
||||
docker run -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
### 高度な使用法
|
||||
|
||||
**カスタム設定:**
|
||||
```bash
|
||||
# カスタムモデルと言語の例
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||
```
|
||||
|
||||
### メモリ要件
|
||||
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
|
||||
|
||||
|
||||
#### カスタマイズ
|
||||
|
||||
- `--build-arg` オプション:
|
||||
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
|
||||
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
|
||||
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
|
||||
|
||||
## 🔮 ユースケース
|
||||
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...
|
||||
71
docs/alignement_principles.md
Normal file
71
docs/alignement_principles.md
Normal file
@@ -0,0 +1,71 @@
|
||||
### Alignment between STT Tokens and Diarization Segments
|
||||
|
||||
- Example 1: The punctuation from STT and the speaker change from Diariation come in the prediction `t`
|
||||
- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diariation come in the prediction `t-1`
|
||||
- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diariation come in the prediction `t`
|
||||
|
||||
> `#` Is the split between the `t-1` prediction and `t` prediction.
|
||||
|
||||
|
||||
## Example 1:
|
||||
```text
|
||||
punctuations_segments : __#_______.__________________!____
|
||||
diarization_segments:
|
||||
SPK1 __#____________
|
||||
SPK2 # ___________________
|
||||
-->
|
||||
ALIGNED SPK1 __#_______.
|
||||
ALIGNED SPK2 # __________________!____
|
||||
|
||||
t-1 output:
|
||||
SPK1: __#
|
||||
SPK2: NO
|
||||
DIARIZATION BUFFER: NO
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: __________________!____
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 2:
|
||||
```text
|
||||
punctuations_segments : _____#__.___________
|
||||
diarization_segments:
|
||||
SPK1 ___ #
|
||||
SPK2 __#______________
|
||||
-->
|
||||
ALIGNED SPK1 _____#__.
|
||||
ALIGNED SPK2 # ___________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___ #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: ___________
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 3:
|
||||
```text
|
||||
punctuations_segments : ___.__#__________
|
||||
diarization_segments:
|
||||
SPK1 ______#__
|
||||
SPK2 # ________
|
||||
-->
|
||||
ALIGNED SPK1 ___. #
|
||||
ALIGNED SPK2 __#__________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___. #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: #
|
||||
SPK2: __#___________
|
||||
DIARIZATION BUFFER: NO
|
||||
```
|
||||
43
docs/technical_integration.md
Normal file
43
docs/technical_integration.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Technical Integration Guide
|
||||
|
||||
This document introduce how to reuse the core components when you do **not** want to ship the bundled frontend, FastAPI server, or even the provided CLI.
|
||||
|
||||
---
|
||||
|
||||
## 1. Runtime Components
|
||||
|
||||
| Layer | File(s) | Purpose |
|
||||
|-------|---------|---------|
|
||||
| Transport | `whisperlivekit/basic_server.py`, any ASGI/WebSocket server | Accepts audio over WebSocket (MediaRecorder WebM or raw PCM chunks) and streams JSON updates back |
|
||||
| Audio processing | `whisperlivekit/audio_processor.py` | Buffers audio, orchestrates transcription, diarization, translation, handles FFmpeg/PCM input |
|
||||
| Engines | `whisperlivekit/core.py`, `whisperlivekit/simul_whisper/*`, `whisperlivekit/local_agreement/*` | Load models once (SimulStreaming or LocalAgreement), expose `TranscriptionEngine` and helpers |
|
||||
| Frontends | `whisperlivekit/web/*`, `chrome-extension/*` | Optional UI layers feeding the WebSocket endpoint |
|
||||
|
||||
**Key idea:** The server boundary is just `AudioProcessor.process_audio()` for incoming bytes and the async generator returned by `AudioProcessor.create_tasks()` for outgoing updates (`FrontData`). Everything else is optional.
|
||||
|
||||
---
|
||||
|
||||
## 2. Running Without the Bundled Frontend
|
||||
|
||||
1. Start the server/engine however you like:
|
||||
```bash
|
||||
wlk --model small --language en --host 0.0.0.0 --port 9000
|
||||
# or launch your own app that instantiates TranscriptionEngine(...)
|
||||
```
|
||||
2. Build your own client (browser, mobile, desktop) that:
|
||||
- Opens `ws(s)://<host>:<port>/asr`
|
||||
- Sends either MediaRecorder/Opus WebM blobs **or** raw PCM (`--pcm-input` on the server tells the client to use the AudioWorklet).
|
||||
- Consumes the JSON payload defined in `docs/API.md`.
|
||||
|
||||
---
|
||||
|
||||
## 3. Running Without FastAPI
|
||||
|
||||
`whisperlivekit/basic_server.py` is just an example. Any async framework works, as long as you:
|
||||
|
||||
1. Create a global `TranscriptionEngine` (expensive to initialize; reuse it).
|
||||
2. Instantiate `AudioProcessor(transcription_engine=engine)` for each connection.
|
||||
3. Call `create_tasks()` to get the async generator, `process_audio()` with incoming bytes, and ensure `cleanup()` runs when the client disconnects.
|
||||
|
||||
|
||||
If you prefer to send compressed audio, instantiate `AudioProcessor(pcm_input=False)` and pipe encoded chunks through `FFmpegManager` transparently—just ensure `ffmpeg` is available or be ready to handle the `"ffmpeg_not_found"` error in the streamed `FrontData`.
|
||||
113
docs/troubleshooting.md
Normal file
113
docs/troubleshooting.md
Normal file
@@ -0,0 +1,113 @@
|
||||
# Troubleshooting
|
||||
|
||||
|
||||
## GPU drivers & cuDNN visibility
|
||||
|
||||
### Linux error: `Unable to load libcudnn_ops.so* / cudnnCreateTensorDescriptor`
|
||||
> Reported in issue #271 (Arch/CachyOS)
|
||||
|
||||
`faster-whisper` (used for the SimulStreaming encoder) dynamically loads cuDNN.
|
||||
If the runtime cannot find `libcudnn_*`, verify that CUDA and cuDNN match the PyTorch build you installed:
|
||||
|
||||
1. **Install CUDA + cuDNN** (Arch/CachyOS example):
|
||||
```bash
|
||||
sudo pacman -S cuda cudnn
|
||||
sudo ldconfig
|
||||
```
|
||||
2. **Make sure the shared objects are visible**:
|
||||
```bash
|
||||
ls /usr/lib/libcudnn*
|
||||
```
|
||||
3. **Check what CUDA version PyTorch expects** and match that with the driver you installed:
|
||||
```bash
|
||||
python - <<'EOF'
|
||||
import torch
|
||||
print(torch.version.cuda)
|
||||
EOF
|
||||
nvcc --version
|
||||
```
|
||||
4. If you installed CUDA in a non-default location, export `CUDA_HOME` and add `$CUDA_HOME/lib64` to `LD_LIBRARY_PATH`.
|
||||
|
||||
Once the CUDA/cuDNN versions match, `whisperlivekit-server` starts normally.
|
||||
|
||||
### Windows error: `Could not locate cudnn_ops64_9.dll`
|
||||
> Reported in issue #286 (Conda on Windows)
|
||||
|
||||
PyTorch bundles cuDNN DLLs inside your environment (`<env>\Lib\site-packages\torch\lib`).
|
||||
When `ctranslate2` or `faster-whisper` cannot find `cudnn_ops64_9.dll`:
|
||||
|
||||
1. Locate the DLL shipped with PyTorch, e.g.
|
||||
```
|
||||
E:\conda\envs\WhisperLiveKit\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
|
||||
```
|
||||
2. Add that directory to your `PATH` **or** copy the `cudnn_*64_9.dll` files into a directory that is already on `PATH` (such as the environment's `Scripts/` folder).
|
||||
3. Restart the shell before launching `wlk`.
|
||||
|
||||
Installing NVIDIA's standalone cuDNN 9.x and pointing `PATH`/`CUDNN_PATH` to it works as well, but is usually not required.
|
||||
|
||||
---
|
||||
|
||||
## PyTorch / CTranslate2 GPU builds
|
||||
|
||||
### `Torch not compiled with CUDA enabled`
|
||||
> Reported in issue #284
|
||||
|
||||
If `torch.zeros(1).cuda()` raises that assertion it means you installed a CPU-only wheel.
|
||||
Install the GPU-enabled wheels that match your CUDA toolkit:
|
||||
|
||||
```bash
|
||||
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
|
||||
```
|
||||
|
||||
Replace `cu130` with the CUDA version supported by your driver (see [PyTorch install selector](https://pytorch.org/get-started/locally/)).
|
||||
Validate with:
|
||||
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.is_available(), torch.cuda.get_device_name())
|
||||
```
|
||||
|
||||
### `CTranslate2 device count: 0` or `Could not infer dtype of ctranslate2._ext.StorageView`
|
||||
> Follow-up in issue #284
|
||||
|
||||
`ctranslate2` publishes separate CPU and CUDA wheels. The default `pip install ctranslate2` brings the CPU build, which makes WhisperLiveKit fall back to CPU tensors and leads to the dtype error above.
|
||||
|
||||
1. Uninstall the CPU build: `pip uninstall -y ctranslate2`.
|
||||
2. Install the CUDA wheel that matches your toolkit (example for CUDA 13.0):
|
||||
```bash
|
||||
pip install ctranslate2==4.5.0 -f https://opennmt.net/ctranslate2/whl/cu130
|
||||
```
|
||||
(See the [CTranslate2 installation table](https://opennmt.net/CTranslate2/installation.html) for other CUDA versions.)
|
||||
3. Verify:
|
||||
```python
|
||||
import ctranslate2
|
||||
print("CUDA devices:", ctranslate2.get_cuda_device_count())
|
||||
```
|
||||
|
||||
If you intentionally want CPU inference, run `wlk --backend whisper` to avoid mixing CPU-only CTranslate2 with a GPU Torch build.
|
||||
|
||||
---
|
||||
|
||||
## Hopper / Blackwell (`sm_121a`) systems
|
||||
> Reported in issue #276 (NVIDIA DGX Spark)
|
||||
|
||||
CUDA 12.1a GPUs ship before some toolchains know about the architecture ID, so Triton/PTXAS need manual hints:
|
||||
|
||||
```bash
|
||||
export CUDA_HOME="/usr/local/cuda-13.0"
|
||||
export PATH="$CUDA_HOME/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
|
||||
|
||||
# Tell Triton where the new ptxas lives
|
||||
export TRITON_PTXAS_PATH="$CUDA_HOME/bin/ptxas"
|
||||
|
||||
# Force PyTorch to JIT kernels for all needed architectures
|
||||
export TORCH_CUDA_ARCH_LIST="8.0 9.0 10.0 12.0 12.1a"
|
||||
```
|
||||
|
||||
After exporting those variables (or adding them to your systemd service / shell profile), restart `wlk`. Incoming streams will now compile kernels targeting `sm_121a` without crashing.
|
||||
|
||||
---
|
||||
|
||||
Need help with another recurring issue? Open a GitHub discussion or PR and reference this document so we can keep it current.
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.14"
|
||||
version = "0.2.15"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -61,10 +61,10 @@ packages = [
|
||||
"whisperlivekit.whisper.normalizers",
|
||||
"whisperlivekit.web",
|
||||
"whisperlivekit.local_agreement",
|
||||
"whisperlivekit.vad_models"
|
||||
"whisperlivekit.silero_vad_models"
|
||||
]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
|
||||
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
||||
|
||||
@@ -14,10 +14,10 @@ from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||
from whisperlivekit.whisper.model import ModelDimensions
|
||||
from whisperlivekit.whisper.utils import exact_div
|
||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||
|
||||
|
||||
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
||||
|
||||
@@ -5,16 +5,18 @@ import argparse
|
||||
import base64
|
||||
import gzip
|
||||
import io
|
||||
import math
|
||||
import pathlib
|
||||
import sys
|
||||
import math
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio as DatasetAudio, load_dataset
|
||||
import soundfile as sf
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from datasets import Audio as DatasetAudio
|
||||
from datasets import load_dataset
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Copy core files from web directory to Chrome extension directory."""
|
||||
|
||||
import shutil
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def sync_extension_files():
|
||||
|
||||
web_dir = Path("whisperlivekit/web")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .audio_processor import AudioProcessor
|
||||
from .core import TranscriptionEngine
|
||||
from .parse_args import parse_args
|
||||
from .web.web_interface import get_web_interface_html, get_inline_ui_html
|
||||
from .web.web_interface import get_inline_ui_html, get_web_interface_html
|
||||
|
||||
__all__ = [
|
||||
"TranscriptionEngine",
|
||||
|
||||
@@ -1,43 +1,52 @@
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from time import time, sleep
|
||||
import math
|
||||
import logging
|
||||
import traceback
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output
|
||||
from time import time
|
||||
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.core import (TranscriptionEngine,
|
||||
online_diarization_factory, online_factory,
|
||||
online_translation_factory)
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
||||
Line, Silence, State, Transcript)
|
||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
def cut_at(cumulative_pcm, cut_sec):
|
||||
cumulative_len = 0
|
||||
cut_sample = int(cut_sec * 16000)
|
||||
async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.ndarray, List[Any]]:
|
||||
items: List[Any] = []
|
||||
|
||||
first_item = await queue.get()
|
||||
queue.task_done()
|
||||
if first_item is SENTINEL:
|
||||
return first_item
|
||||
if isinstance(first_item, Silence):
|
||||
return first_item
|
||||
items.append(first_item)
|
||||
|
||||
for ind, pcm_array in enumerate(cumulative_pcm):
|
||||
if (cumulative_len + len(pcm_array)) >= cut_sample:
|
||||
cut_chunk = cut_sample - cumulative_len
|
||||
before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]])
|
||||
after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:]
|
||||
return before, after
|
||||
cumulative_len += len(pcm_array)
|
||||
return np.concatenate(cumulative_pcm), []
|
||||
|
||||
async def get_all_from_queue(queue):
|
||||
items = []
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
items.append(item)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
return items
|
||||
while True:
|
||||
if not queue._queue:
|
||||
break
|
||||
next_item = queue._queue[0]
|
||||
if next_item is SENTINEL:
|
||||
break
|
||||
if isinstance(next_item, Silence):
|
||||
break
|
||||
items.append(await queue.get())
|
||||
queue.task_done()
|
||||
if isinstance(items[0], np.ndarray):
|
||||
return np.concatenate(items)
|
||||
else: #translation
|
||||
return items
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
@@ -45,7 +54,7 @@ class AudioProcessor:
|
||||
Handles audio processing, state management, and result formatting.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the audio processor with configuration, models, and state."""
|
||||
|
||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||
@@ -64,36 +73,27 @@ class AudioProcessor:
|
||||
self.is_pcm_input = self.args.pcm_input
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.silence = False
|
||||
self.silence_duration = 0.0
|
||||
self.state = State()
|
||||
self.lock = asyncio.Lock()
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = FrontData()
|
||||
self.last_detected_speaker = None
|
||||
self.speaker_languages = {}
|
||||
self.diarization_before_transcription = False
|
||||
self.is_stopping: bool = False
|
||||
self.current_silence: Optional[Silence] = None
|
||||
self.state: State = State()
|
||||
self.lock: asyncio.Lock = asyncio.Lock()
|
||||
self.sep: str = " " # Default separator
|
||||
self.last_response_content: FrontData = FrontData()
|
||||
|
||||
self.segments = []
|
||||
|
||||
self.tokens_alignment: TokensAlignment = TokensAlignment(self.state, self.args, self.sep)
|
||||
self.beg_loop: Optional[float] = None
|
||||
|
||||
if self.diarization_before_transcription:
|
||||
self.cumulative_pcm = []
|
||||
self.last_start = 0.0
|
||||
self.last_end = 0.0
|
||||
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.vac_model = models.vac_model
|
||||
self.asr: Any = models.asr
|
||||
self.vac_model: Any = models.vac_model
|
||||
if self.args.vac:
|
||||
self.vac = FixedVADIterator(models.vac_model)
|
||||
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
|
||||
else:
|
||||
self.vac = None
|
||||
self.vac: Optional[FixedVADIterator] = None
|
||||
|
||||
self.ffmpeg_manager = None
|
||||
self.ffmpeg_reader_task = None
|
||||
self._ffmpeg_error = None
|
||||
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
||||
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
||||
self._ffmpeg_error: Optional[str] = None
|
||||
|
||||
if not self.is_pcm_input:
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
@@ -105,20 +105,20 @@ class AudioProcessor:
|
||||
self._ffmpeg_error = error_type
|
||||
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer = bytearray()
|
||||
|
||||
self.transcription_task = None
|
||||
self.diarization_task = None
|
||||
self.translation_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer: bytearray = bytearray()
|
||||
self.total_pcm_samples: int = 0
|
||||
self.transcription_task: Optional[asyncio.Task] = None
|
||||
self.diarization_task: Optional[asyncio.Task] = None
|
||||
self.translation_task: Optional[asyncio.Task] = None
|
||||
self.watchdog_task: Optional[asyncio.Task] = None
|
||||
self.all_tasks_for_cleanup: List[asyncio.Task] = []
|
||||
|
||||
self.transcription = None
|
||||
self.translation = None
|
||||
self.diarization = None
|
||||
self.transcription: Optional[Any] = None
|
||||
self.translation: Optional[Any] = None
|
||||
self.diarization: Optional[Any] = None
|
||||
|
||||
if self.args.transcription:
|
||||
self.transcription = online_factory(self.args, models.asr)
|
||||
@@ -128,27 +128,67 @@ class AudioProcessor:
|
||||
if models.translation_model:
|
||||
self.translation = online_translation_factory(self.args, models.translation_model)
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
async def _push_silence_event(self) -> None:
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(self.current_silence)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(self.current_silence)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(self.current_silence)
|
||||
|
||||
async def _begin_silence(self) -> None:
|
||||
if self.current_silence:
|
||||
return
|
||||
now = time() - self.beg_loop
|
||||
self.current_silence = Silence(
|
||||
is_starting=True, start=now
|
||||
)
|
||||
await self._push_silence_event()
|
||||
|
||||
async def _end_silence(self) -> None:
|
||||
if not self.current_silence:
|
||||
return
|
||||
now = time() - self.beg_loop
|
||||
self.current_silence.end = now
|
||||
self.current_silence.is_starting=False
|
||||
self.current_silence.has_ended=True
|
||||
self.current_silence.compute_duration()
|
||||
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
||||
self.state.new_tokens.append(self.current_silence)
|
||||
await self._push_silence_event()
|
||||
self.current_silence = None
|
||||
|
||||
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray) -> None:
|
||||
if pcm_chunk is None or pcm_chunk.size == 0:
|
||||
return
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_chunk.copy())
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_chunk.copy())
|
||||
|
||||
def _slice_before_silence(self, pcm_array: np.ndarray, chunk_sample_start: int, silence_sample: Optional[int]) -> Optional[np.ndarray]:
|
||||
if silence_sample is None:
|
||||
return None
|
||||
relative_index = int(silence_sample - chunk_sample_start)
|
||||
if relative_index <= 0:
|
||||
return None
|
||||
split_index = min(relative_index, len(pcm_array))
|
||||
if split_index <= 0:
|
||||
return None
|
||||
return pcm_array[:split_index]
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
|
||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
async def add_dummy_token(self):
|
||||
"""Placeholder token when no transcription is available."""
|
||||
async with self.lock:
|
||||
current_time = time() - self.state.beg_loop
|
||||
self.state.tokens.append(ASRToken(
|
||||
start=current_time, end=current_time + 1,
|
||||
text=".", speaker=-1, is_dummy=True
|
||||
))
|
||||
|
||||
async def get_current_state(self):
|
||||
async def get_current_state(self) -> State:
|
||||
"""Get current state."""
|
||||
async with self.lock:
|
||||
current_time = time()
|
||||
|
||||
remaining_transcription = 0
|
||||
if self.state.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.state.beg_loop - self.state.end_buffer, 1))
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
|
||||
|
||||
remaining_diarization = 0
|
||||
if self.state.tokens:
|
||||
@@ -160,7 +200,7 @@ class AudioProcessor:
|
||||
|
||||
return self.state
|
||||
|
||||
async def ffmpeg_stdout_reader(self):
|
||||
async def ffmpeg_stdout_reader(self) -> None:
|
||||
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||
beg = time()
|
||||
while True:
|
||||
@@ -203,50 +243,60 @@ class AudioProcessor:
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
if self.diarization:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
if self.translation:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
async def transcription_processor(self):
|
||||
async def transcription_processor(self) -> None:
|
||||
"""Process audio chunks for transcription."""
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = await self.transcription_queue.get()
|
||||
# item = await self.transcription_queue.get()
|
||||
item = await get_all_from_queue(self.transcription_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
self.transcription_queue.task_done()
|
||||
break
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer)
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.state.end_buffer)
|
||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||
if type(item) is Silence:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
new_tokens = []
|
||||
current_audio_processed_upto = self.state.end_buffer
|
||||
|
||||
if isinstance(item, Silence):
|
||||
if item.is_starting:
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
|
||||
self.transcription.start_silence
|
||||
)
|
||||
asr_processing_logs += f" + Silence starting"
|
||||
if item.has_ended:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
current_audio_processed_upto = cumulative_pcm_duration_stream_time
|
||||
self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
if self.state.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
continue
|
||||
new_tokens = new_tokens or []
|
||||
current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
self.transcription.new_speaker(item)
|
||||
continue
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
logger.info(asr_processing_logs)
|
||||
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
|
||||
new_tokens = new_tokens or []
|
||||
|
||||
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
|
||||
|
||||
_buffer_transcript = self.transcription.get_buffer()
|
||||
buffer_text = _buffer_transcript.text
|
||||
|
||||
@@ -269,13 +319,12 @@ class AudioProcessor:
|
||||
self.state.tokens.extend(new_tokens)
|
||||
self.state.buffer_transcription = _buffer_transcript
|
||||
self.state.end_buffer = max(candidate_end_times)
|
||||
|
||||
self.state.new_tokens.extend(new_tokens)
|
||||
self.state.new_tokens_buffer = _buffer_transcript
|
||||
|
||||
if self.translation_queue:
|
||||
for token in new_tokens:
|
||||
await self.translation_queue.put(token)
|
||||
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
await self.translation_queue.put(token)
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in transcription_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
@@ -292,124 +341,57 @@ class AudioProcessor:
|
||||
logger.info("Transcription processor task finished.")
|
||||
|
||||
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
"""Process audio chunks for speaker diarization."""
|
||||
if self.diarization_before_transcription:
|
||||
self.current_speaker = 0
|
||||
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0))
|
||||
async def diarization_processor(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
item = await self.diarization_queue.get()
|
||||
item = await get_all_from_queue(self.diarization_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
elif type(item) is Silence:
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
if item.has_ended:
|
||||
self.diarization.insert_silence(item.duration)
|
||||
continue
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
|
||||
|
||||
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
if self.diarization_before_transcription:
|
||||
segments = diarization_obj.get_segments()
|
||||
self.cumulative_pcm.append(pcm_array)
|
||||
if segments:
|
||||
last_segment = segments[-1]
|
||||
if last_segment.speaker != self.current_speaker:
|
||||
cut_sec = last_segment.start - self.last_end
|
||||
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
|
||||
await self.transcription_queue.put(to_transcript)
|
||||
|
||||
self.current_speaker = last_segment.speaker
|
||||
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=last_segment.start))
|
||||
|
||||
cut_sec = last_segment.end - last_segment.start
|
||||
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
|
||||
await self.transcription_queue.put(to_transcript)
|
||||
self.last_start = last_segment.start
|
||||
self.last_end = last_segment.end
|
||||
else:
|
||||
cut_sec = last_segment.end - self.last_end
|
||||
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
|
||||
await self.transcription_queue.put(to_transcript)
|
||||
self.last_end = last_segment.end
|
||||
elif not self.diarization_before_transcription:
|
||||
async with self.lock:
|
||||
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.state.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.state.tokens) > 0:
|
||||
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
self.diarization.insert_audio_chunk(item)
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
self.state.new_diarization = diarization_segments
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL:
|
||||
self.diarization_queue.task_done()
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self):
|
||||
async def translation_processor(self) -> None:
|
||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||
# And the speaker is attributed given the segments used for the translation
|
||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||
while True:
|
||||
try:
|
||||
item = await self.translation_queue.get() #block until at least 1 token
|
||||
item = await get_all_from_queue(self.translation_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
elif type(item) is Silence:
|
||||
self.translation.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
# get all the available tokens for translation. The more words, the more precise
|
||||
tokens_to_process = [item]
|
||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
||||
|
||||
sentinel_found = False
|
||||
for additional_token in additional_tokens:
|
||||
if additional_token is SENTINEL:
|
||||
sentinel_found = True
|
||||
break
|
||||
elif type(additional_token) is Silence:
|
||||
self.translation.insert_silence(additional_token.duration)
|
||||
if item.is_starting:
|
||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||
if item.has_ended:
|
||||
self.translation.insert_silence(item.duration)
|
||||
continue
|
||||
else:
|
||||
tokens_to_process.append(additional_token)
|
||||
if tokens_to_process:
|
||||
self.translation.insert_tokens(tokens_to_process)
|
||||
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.translation_validated_segments = translation_validated_segments
|
||||
self.state.buffer_translation = buffer_translation
|
||||
self.translation_queue.task_done()
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
|
||||
if sentinel_found:
|
||||
logger.debug("Translation processor received sentinel in batch. Finishing.")
|
||||
break
|
||||
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
new_translation, new_translation_buffer = self.translation.validate_buffer_and_reset()
|
||||
pass
|
||||
else:
|
||||
self.translation.insert_tokens(item)
|
||||
new_translation, new_translation_buffer = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.new_translation.append(new_translation)
|
||||
self.state.new_translation_buffer = new_translation_buffer
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'token' in locals() and item is not SENTINEL:
|
||||
self.translation_queue.task_done()
|
||||
if 'additional_tokens' in locals():
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
logger.info("Translation processor task finished.")
|
||||
|
||||
async def results_formatter(self):
|
||||
async def results_formatter(self) -> AsyncGenerator[FrontData, None]:
|
||||
"""Format processing results for output."""
|
||||
while True:
|
||||
try:
|
||||
@@ -419,55 +401,32 @@ class AudioProcessor:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
state = await self.get_current_state()
|
||||
|
||||
lines, undiarized_text = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
args = self.args,
|
||||
sep=self.sep
|
||||
self.tokens_alignment.update()
|
||||
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
||||
diarization=self.args.diarization,
|
||||
translation=bool(self.translation),
|
||||
current_silence=self.current_silence
|
||||
)
|
||||
if lines and lines[-1].speaker == -2:
|
||||
buffer_transcription = Transcript()
|
||||
else:
|
||||
buffer_transcription = state.buffer_transcription
|
||||
state = await self.get_current_state()
|
||||
|
||||
buffer_diarization = ''
|
||||
if undiarized_text:
|
||||
buffer_diarization = self.sep.join(undiarized_text)
|
||||
buffer_transcription_text = state.buffer_transcription.text if state.buffer_transcription else ''
|
||||
|
||||
async with self.lock:
|
||||
self.state.end_attributed_speaker = state.end_attributed_speaker
|
||||
|
||||
buffer_translation_text = ''
|
||||
if state.buffer_translation:
|
||||
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
|
||||
if raw_buffer_translation:
|
||||
buffer_translation_text = raw_buffer_translation.strip()
|
||||
|
||||
response_status = "active_transcription"
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
if not lines and not buffer_transcription_text and not buffer_diarization_text:
|
||||
response_status = "no_audio_detected"
|
||||
lines = []
|
||||
elif not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.end_buffer,
|
||||
end=state.end_buffer
|
||||
)]
|
||||
|
||||
|
||||
response = FrontData(
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text.strip(),
|
||||
buffer_diarization=buffer_diarization,
|
||||
buffer_transcription=buffer_transcription_text,
|
||||
buffer_diarization=buffer_diarization_text,
|
||||
buffer_translation=buffer_translation_text,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
|
||||
should_push = (response != self.last_response_content)
|
||||
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
if should_push:
|
||||
yield response
|
||||
self.last_response_content = response
|
||||
|
||||
@@ -481,17 +440,17 @@ class AudioProcessor:
|
||||
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def create_tasks(self):
|
||||
async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
|
||||
"""Create and start processing tasks."""
|
||||
self.all_tasks_for_cleanup = []
|
||||
processing_tasks_for_watchdog = []
|
||||
processing_tasks_for_watchdog: List[asyncio.Task] = []
|
||||
|
||||
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
|
||||
if not self.is_pcm_input:
|
||||
success = await self.ffmpeg_manager.start()
|
||||
if not success:
|
||||
logger.error("Failed to start FFmpeg manager")
|
||||
async def error_generator():
|
||||
async def error_generator() -> AsyncGenerator[FrontData, None]:
|
||||
yield FrontData(
|
||||
status="error",
|
||||
error="FFmpeg failed to start. Please check that FFmpeg is installed."
|
||||
@@ -507,7 +466,7 @@ class AudioProcessor:
|
||||
processing_tasks_for_watchdog.append(self.transcription_task)
|
||||
|
||||
if self.diarization:
|
||||
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization))
|
||||
self.diarization_task = asyncio.create_task(self.diarization_processor())
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
@@ -522,9 +481,9 @@ class AudioProcessor:
|
||||
|
||||
return self.results_formatter()
|
||||
|
||||
async def watchdog(self, tasks_to_monitor):
|
||||
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
|
||||
"""Monitors the health of critical processing tasks."""
|
||||
tasks_remaining = [task for task in tasks_to_monitor if task]
|
||||
tasks_remaining: List[asyncio.Task] = [task for task in tasks_to_monitor if task]
|
||||
while True:
|
||||
try:
|
||||
if not tasks_remaining:
|
||||
@@ -549,7 +508,7 @@ class AudioProcessor:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in watchdog task: {e}", exc_info=True)
|
||||
|
||||
async def cleanup(self):
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources when processing is complete."""
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
self.is_stopping = True
|
||||
@@ -572,7 +531,7 @@ class AudioProcessor:
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
def _processing_tasks_done(self):
|
||||
def _processing_tasks_done(self) -> bool:
|
||||
"""Return True when all active processing tasks have completed."""
|
||||
tasks_to_check = [
|
||||
self.transcription_task,
|
||||
@@ -583,11 +542,13 @@ class AudioProcessor:
|
||||
return all(task.done() for task in tasks_to_check if task)
|
||||
|
||||
|
||||
async def process_audio(self, message):
|
||||
async def process_audio(self, message: Optional[bytes]) -> None:
|
||||
"""Process incoming audio data."""
|
||||
|
||||
if not self.state.beg_loop:
|
||||
self.state.beg_loop = time()
|
||||
if not self.beg_loop:
|
||||
self.beg_loop = time()
|
||||
self.current_silence = Silence(start=0.0, is_starting=True)
|
||||
self.tokens_alignment.beg_loop = self.beg_loop
|
||||
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
@@ -620,7 +581,7 @@ class AudioProcessor:
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
|
||||
async def handle_pcm_data(self):
|
||||
async def handle_pcm_data(self) -> None:
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) < self.bytes_per_sec:
|
||||
return
|
||||
@@ -639,40 +600,30 @@ class AudioProcessor:
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
|
||||
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
|
||||
|
||||
res = None
|
||||
end_of_audio = False
|
||||
silence_buffer = None
|
||||
num_samples = len(pcm_array)
|
||||
chunk_sample_start = self.total_pcm_samples
|
||||
chunk_sample_end = chunk_sample_start + num_samples
|
||||
|
||||
res = None
|
||||
if self.args.vac:
|
||||
res = self.vac(pcm_array)
|
||||
|
||||
if res is not None:
|
||||
if res.get("end", 0) > res.get("start", 0):
|
||||
end_of_audio = True
|
||||
elif self.silence: #end of silence
|
||||
self.silence = False
|
||||
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||
if "start" in res and self.current_silence:
|
||||
await self._end_silence()
|
||||
|
||||
if "end" in res and not self.current_silence:
|
||||
pre_silence_chunk = self._slice_before_silence(
|
||||
pcm_array, chunk_sample_start, res.get("end")
|
||||
)
|
||||
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
||||
await self._enqueue_active_audio(pre_silence_chunk)
|
||||
await self._begin_silence()
|
||||
|
||||
if silence_buffer:
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(silence_buffer)
|
||||
if not self.current_silence:
|
||||
await self._enqueue_active_audio(pcm_array)
|
||||
|
||||
if not self.silence:
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_array.copy())
|
||||
|
||||
self.silence_duration = 0.0
|
||||
|
||||
if end_of_audio:
|
||||
self.silence = True
|
||||
self.start_silence = time()
|
||||
self.total_pcm_samples = chunk_sample_end
|
||||
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
|
||||
get_inline_ui_html, parse_args)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
|
||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||
from argparse import Namespace
|
||||
import sys
|
||||
import logging
|
||||
|
||||
|
||||
def update_with_kwargs(_dict, kwargs):
|
||||
_dict.update({
|
||||
@@ -52,8 +54,8 @@ class TranscriptionEngine:
|
||||
|
||||
transcription_common_params = {
|
||||
"warmup_file": None,
|
||||
"min_chunk_size": 0.5,
|
||||
"model_size": "tiny",
|
||||
"min_chunk_size": 0.1,
|
||||
"model_size": "base",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"model_path": None,
|
||||
@@ -80,6 +82,7 @@ class TranscriptionEngine:
|
||||
|
||||
if self.args.vac:
|
||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||
|
||||
# Use ONNX if specified, otherwise use JIT (default)
|
||||
use_onnx = kwargs.get('vac_onnx', False)
|
||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||
@@ -100,7 +103,6 @@ class TranscriptionEngine:
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"preload_model_count": 1,
|
||||
}
|
||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||
|
||||
@@ -135,7 +137,8 @@ class TranscriptionEngine:
|
||||
|
||||
if self.args.diarization:
|
||||
if self.args.diarization_backend == "diart":
|
||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||
from whisperlivekit.diarization.diart_backend import \
|
||||
DiartDiarization
|
||||
diart_params = {
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
@@ -146,7 +149,8 @@ class TranscriptionEngine:
|
||||
**diart_params
|
||||
)
|
||||
elif self.args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
from whisperlivekit.diarization.sortformer_backend import \
|
||||
SortformerDiarization
|
||||
self.diarization_model = SortformerDiarization()
|
||||
|
||||
self.translation_model = None
|
||||
@@ -182,7 +186,8 @@ def online_diarization_factory(args, diarization_backend):
|
||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
||||
|
||||
if args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
||||
from whisperlivekit.diarization.sortformer_backend import \
|
||||
SortformerDiarizationOnline
|
||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||
return online
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
from queue import SimpleQueue, Empty
|
||||
from queue import Empty, SimpleQueue
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import diart.models as m
|
||||
import numpy as np
|
||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||
from diart.inference import StreamingInference
|
||||
from diart.sources import AudioSource
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
from diart.sources import MicrophoneAudioSource
|
||||
from rx.core import Observer
|
||||
from typing import Tuple, Any, List
|
||||
from diart.sources import AudioSource, MicrophoneAudioSource
|
||||
from pyannote.core import Annotation
|
||||
import diart.models as m
|
||||
from rx.core import Observer
|
||||
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,7 +26,7 @@ class DiarizationObserver(Observer):
|
||||
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
||||
|
||||
def __init__(self):
|
||||
self.speaker_segments = []
|
||||
self.diarization_segments = []
|
||||
self.processed_time = 0
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
@@ -48,7 +48,7 @@ class DiarizationObserver(Observer):
|
||||
for speaker, label in annotation._labels.items():
|
||||
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
|
||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
self.diarization_segments.append(SpeakerSegment(
|
||||
speaker=speaker,
|
||||
start=start + self.global_time_offset,
|
||||
end=end + self.global_time_offset
|
||||
@@ -59,14 +59,14 @@ class DiarizationObserver(Observer):
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.speaker_segments.copy()
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.speaker_segments = [
|
||||
segment for segment in self.speaker_segments
|
||||
self.diarization_segments = [
|
||||
segment for segment in self.diarization_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
|
||||
@@ -178,7 +178,6 @@ class DiartDiarization:
|
||||
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
self.lag_diart = None
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
@@ -217,32 +216,6 @@ class DiartDiarization:
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
Uses the segments collected by the observer.
|
||||
|
||||
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
|
||||
"""
|
||||
segments = self.observer.get_segments()
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
|
||||
logger.debug(f"Available segments: {len(segments)}")
|
||||
for i, seg in enumerate(segments[:5]): # Show first 5 segments
|
||||
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
|
||||
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
|
||||
if not use_punctuation_split:
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
else:
|
||||
tokens = add_speaker_to_tokens(segments, tokens)
|
||||
return tokens
|
||||
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from queue import Empty, SimpleQueue
|
||||
from typing import List, Optional
|
||||
from queue import SimpleQueue, Empty
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
|
||||
@@ -94,11 +95,11 @@ class SortformerDiarizationOnline:
|
||||
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.speaker_segments = []
|
||||
self.diarization_segments = []
|
||||
self.diar_segments = []
|
||||
self.buffer_audio = np.array([], dtype=np.float32)
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
self.processed_time = 0.0
|
||||
self.debug = False
|
||||
|
||||
self.diar_model = shared_model.diar_model
|
||||
@@ -155,12 +156,10 @@ class SortformerDiarizationOnline:
|
||||
)
|
||||
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
|
||||
# Initialize total predictions tensor
|
||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
def insert_silence(self, silence_duration: Optional[float]):
|
||||
"""
|
||||
Insert silence period by adjusting the global time offset.
|
||||
|
||||
@@ -171,248 +170,111 @@ class SortformerDiarizationOnline:
|
||||
self.global_time_offset += silence_duration
|
||||
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
|
||||
|
||||
async def diarize(self):
|
||||
"""
|
||||
Process audio data for diarization in streaming fashion.
|
||||
|
||||
Args:
|
||||
pcm_array: Audio data as numpy array
|
||||
"""
|
||||
try:
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||
|
||||
if not len(self.buffer_audio) >= threshold:
|
||||
return []
|
||||
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
device = self.diar_model.device
|
||||
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||
|
||||
if self._previous_chunk_features is not None:
|
||||
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||
else:
|
||||
total_features = processed_signal_chunk.to(device)
|
||||
|
||||
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
right_offset = 8
|
||||
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
if not len(self.buffer_audio) >= threshold:
|
||||
return
|
||||
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
device = self.diar_model.device
|
||||
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||
|
||||
if self._previous_chunk_features is not None:
|
||||
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||
else:
|
||||
total_features = processed_signal_chunk.to(device)
|
||||
|
||||
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
right_offset = 8
|
||||
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
|
||||
# Convert predictions to speaker segments
|
||||
self._process_predictions()
|
||||
|
||||
self._chunk_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in diarize: {e}")
|
||||
raise
|
||||
|
||||
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
new_segments = self._process_predictions()
|
||||
|
||||
self._chunk_index += 1
|
||||
return new_segments
|
||||
|
||||
def _process_predictions(self):
|
||||
"""Process model predictions and convert to speaker segments."""
|
||||
try:
|
||||
preds_np = self.total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
|
||||
if self._len_prediction is None:
|
||||
self._len_prediction = len(active_speakers)
|
||||
|
||||
# Get predictions for current chunk
|
||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||
|
||||
with self.segment_lock:
|
||||
# Process predictions into segments
|
||||
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||
|
||||
for idx, spk in enumerate(current_chunk_preds):
|
||||
start_time = base_time + idx * frame_duration
|
||||
end_time = base_time + (idx + 1) * frame_duration
|
||||
|
||||
# Check if this continues the last segment or starts a new one
|
||||
if (self.speaker_segments and
|
||||
self.speaker_segments[-1].speaker == spk and
|
||||
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
|
||||
# Continue existing segment
|
||||
self.speaker_segments[-1].end = end_time
|
||||
else:
|
||||
|
||||
# Create new segment
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
speaker=spk,
|
||||
start=start_time,
|
||||
end=end_time
|
||||
))
|
||||
|
||||
# Update processed time
|
||||
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
|
||||
|
||||
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing predictions: {e}")
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
preds_np = self.total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
|
||||
Args:
|
||||
tokens: List of tokens with timing information
|
||||
use_punctuation_split: Whether to use punctuation for boundary refinement
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
Last speaker_segment
|
||||
"""
|
||||
if self._len_prediction is None:
|
||||
self._len_prediction = len(active_speakers) #12
|
||||
|
||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||
|
||||
new_segments = []
|
||||
|
||||
with self.segment_lock:
|
||||
segments = self.speaker_segments.copy()
|
||||
|
||||
if not segments or not tokens:
|
||||
logger.debug("No segments or tokens available for speaker assignment")
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
use_punctuation_split = False
|
||||
if not use_punctuation_split:
|
||||
# Simple overlap-based assignment
|
||||
for token in tokens:
|
||||
token.speaker = -1 # Default to no speaker
|
||||
for segment in segments:
|
||||
# Check for timing overlap
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
|
||||
break
|
||||
else:
|
||||
# Use punctuation-aware assignment (similar to diart_backend)
|
||||
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
|
||||
"""
|
||||
Assign speakers to tokens with punctuation-aware boundary adjustment.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
tokens: List of tokens to assign speakers to
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
|
||||
# Convert segments to concatenated format
|
||||
segments_concatenated = self._concatenate_speakers(segments)
|
||||
|
||||
# Adjust segment boundaries based on punctuation
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
|
||||
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
# Ensure non-overlapping tokens
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
# Assign speakers based on adjusted segments
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
|
||||
return tokens
|
||||
|
||||
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
|
||||
"""
|
||||
Concatenate consecutive segments from the same speaker.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
|
||||
Returns:
|
||||
List of concatenated speaker segments
|
||||
"""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
|
||||
|
||||
for segment in segments[1:]:
|
||||
speaker = segment.speaker + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||
current_spk = current_chunk_preds[0]
|
||||
start_time = round(base_time, 2)
|
||||
for idx, spk in enumerate(current_chunk_preds):
|
||||
current_time = round(base_time + idx * frame_duration, 2)
|
||||
if spk != current_spk:
|
||||
new_segments.append(SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=current_time
|
||||
))
|
||||
start_time = current_time
|
||||
current_spk = spk
|
||||
new_segments.append(
|
||||
SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=current_time
|
||||
)
|
||||
)
|
||||
return new_segments
|
||||
|
||||
return segments_concatenated
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.speaker_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.speaker_segments = [
|
||||
segment for segment in self.speaker_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def close(self):
|
||||
"""Close the diarization system and clean up resources."""
|
||||
logger.info("Closing SortformerDiarization")
|
||||
with self.segment_lock:
|
||||
self.speaker_segments.clear()
|
||||
self.diarization_segments.clear()
|
||||
|
||||
if self.debug:
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
@@ -434,11 +296,12 @@ def extract_number(s: str) -> int:
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
import librosa
|
||||
|
||||
async def main():
|
||||
"""TEST ONLY."""
|
||||
an4_audio = 'audio_test.mp3'
|
||||
an4_audio = 'diarization_audio.wav'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
|
||||
@@ -450,13 +313,15 @@ if __name__ == '__main__':
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
diarization = SortformerDiarization(sample_rate=16000)
|
||||
diarization_backend = SortformerDiarization()
|
||||
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
|
||||
chunk_size = 1600
|
||||
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
await diarization.diarize(chunk)
|
||||
new_segments = await diarization.diarize(chunk)
|
||||
print(f"Processed chunk {i // chunk_size + 1}")
|
||||
print(new_segments)
|
||||
|
||||
segments = diarization.get_segments()
|
||||
print("\nDiarization results:")
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
import librosa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_model():
|
||||
|
||||
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
diar_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
diar_model.to(torch.device("cuda"))
|
||||
|
||||
#we target 1 second lag for the moment. chunk_len could be reduced.
|
||||
diar_model.sortformer_modules.chunk_len = 10
|
||||
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
|
||||
|
||||
diar_model.sortformer_modules.chunk_right_context = 0 #no.
|
||||
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
|
||||
|
||||
diar_model.sortformer_modules.spkcache_len = 188
|
||||
diar_model.sortformer_modules.fifo_len = 188
|
||||
diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
diar_model.sortformer_modules.log = False
|
||||
diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
|
||||
audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size= 0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
|
||||
|
||||
return diar_model, audio2mel
|
||||
|
||||
diar_model, audio2mel = load_model()
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
spkcache = None # Speaker cache to store embeddings from start
|
||||
spkcache_lengths = None #
|
||||
spkcache_preds = None # speaker cache predictions
|
||||
fifo = None # to save the embedding from the latest chunks
|
||||
fifo_lengths = None
|
||||
fifo_preds = None
|
||||
spk_perm = None
|
||||
mean_sil_emb = None
|
||||
n_sil_frames = None
|
||||
|
||||
|
||||
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
|
||||
"""
|
||||
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size for tensors in streaming state
|
||||
async_streaming (bool): True for asynchronous update, False for synchronous update
|
||||
device (torch.device): Device for tensors in streaming state
|
||||
|
||||
Returns:
|
||||
streaming_state (SortformerStreamingState): initialized streaming state
|
||||
"""
|
||||
streaming_state = StreamingSortformerState()
|
||||
if async_streaming:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
|
||||
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
|
||||
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
|
||||
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
else:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
|
||||
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
return streaming_state
|
||||
|
||||
|
||||
def process_diarization(chunks):
|
||||
"""
|
||||
what it does:
|
||||
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
|
||||
2. STFT: Computes the Short-Time Fourier Transform using:
|
||||
- the window of window_size=0.025 --> size of a window : 400 samples
|
||||
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
|
||||
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
|
||||
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
|
||||
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
|
||||
6. Normalization: Skips normalization since `normalize="NA"`
|
||||
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
|
||||
"""
|
||||
previous_chunk = None
|
||||
l_chunk_feat_seq_t = []
|
||||
for chunk in chunks:
|
||||
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
|
||||
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
|
||||
if previous_chunk is not None:
|
||||
to_add = previous_chunk[:, :, -99:]
|
||||
total = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||
else:
|
||||
total = processed_signal_chunk
|
||||
previous_chunk = processed_signal_chunk
|
||||
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
|
||||
|
||||
batch_size = 1
|
||||
streaming_state = init_streaming_state(diar_model.sortformer_modules,
|
||||
batch_size = batch_size,
|
||||
async_streaming = True,
|
||||
device = diar_model.device
|
||||
)
|
||||
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
|
||||
|
||||
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
|
||||
|
||||
l_speakers = [
|
||||
{'start_time': 0,
|
||||
'end_time': 0,
|
||||
'speaker': 0
|
||||
}
|
||||
]
|
||||
len_prediction = None
|
||||
left_offset = 0
|
||||
right_offset = 8
|
||||
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
|
||||
with torch.inference_mode():
|
||||
streaming_state, total_preds = diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||
streaming_state=streaming_state,
|
||||
total_preds=total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
left_offset = 8
|
||||
preds_np = total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
if len_prediction is None:
|
||||
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
|
||||
frame_duration = chunk_duration_seconds / len_prediction
|
||||
active_speakers = active_speakers[-len_prediction:]
|
||||
for idx, spk in enumerate(active_speakers):
|
||||
if spk != l_speakers[-1]['speaker']:
|
||||
l_speakers.append(
|
||||
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
|
||||
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
|
||||
'speaker': spk
|
||||
})
|
||||
else:
|
||||
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
|
||||
|
||||
|
||||
"""
|
||||
Should print
|
||||
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
|
||||
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
|
||||
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
|
||||
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
|
||||
"""
|
||||
for speaker in l_speakers:
|
||||
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
# signal = signal[:-(len(signal)%16000)]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Expected ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
chunk_size = 16000 # 1 second
|
||||
chunks = []
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
|
||||
process_diarization(chunks)
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable
|
||||
import contextlib
|
||||
from typing import Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import sys
|
||||
import logging
|
||||
import io
|
||||
import soundfile as sf
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
@@ -165,8 +168,8 @@ class MLXWhisper(ASRBase):
|
||||
sep = ""
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||
import mlx.core as mx
|
||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
@@ -224,7 +227,8 @@ class MLXWhisper(ASRBase):
|
||||
if segment.get("no_speech_prob", 0) > 0.9:
|
||||
continue
|
||||
for word in segment.get("words", []):
|
||||
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
|
||||
probability=word["probability"]
|
||||
token = ASRToken(word["start"], word["end"], word["word"])
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import sys
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -151,21 +153,32 @@ class OnlineASRProcessor:
|
||||
"""Append an audio chunk (a numpy array) to the current audio buffer."""
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
"""
|
||||
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||
"""
|
||||
# if self.transcript_buffer.buffer:
|
||||
# self.committed.extend(self.transcript_buffer.buffer)
|
||||
# self.transcript_buffer.buffer = []
|
||||
|
||||
if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence.
|
||||
gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16)
|
||||
self.insert_audio_chunk(gap_silence)
|
||||
def start_silence(self):
|
||||
if self.audio_buffer.size == 0:
|
||||
return [], self.get_audio_buffer_end_time()
|
||||
return self.process_iter()
|
||||
|
||||
def end_silence(self, silence_duration: Optional[float], offset: float):
|
||||
if not silence_duration or silence_duration <= 0:
|
||||
return
|
||||
|
||||
long_silence = silence_duration >= 5
|
||||
if not long_silence:
|
||||
gap_samples = int(self.SAMPLING_RATE * silence_duration)
|
||||
if gap_samples > 0:
|
||||
gap_silence = np.zeros(gap_samples, dtype=np.float32)
|
||||
self.insert_audio_chunk(gap_silence)
|
||||
else:
|
||||
self.init(offset=silence_duration + offset)
|
||||
|
||||
self.global_time_offset += silence_duration
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
"""
|
||||
Backwards compatibility shim for legacy callers that still use insert_silence.
|
||||
"""
|
||||
self.end_silence(silence_duration, offset)
|
||||
|
||||
def prompt(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns a tuple: (prompt, context), where:
|
||||
@@ -400,11 +413,11 @@ class OnlineASRProcessor:
|
||||
) -> Transcript:
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
# probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
return Transcript(start, end, text)
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import numpy as np
|
||||
import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
import platform
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
|
||||
import sys
|
||||
import time
|
||||
from functools import lru_cache
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument(
|
||||
@@ -81,14 +82,14 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.5,
|
||||
default=0.1,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="small",
|
||||
default="base",
|
||||
dest='model_size',
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
@@ -295,14 +296,6 @@ def parse_args():
|
||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--preload-model-count",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="preload_model_count",
|
||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from time import time
|
||||
import re
|
||||
|
||||
MIN_SILENCE_DURATION = 4 #in seconds
|
||||
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
|
||||
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
|
||||
|
||||
def blank_to_silence(tokens):
|
||||
full_string = ''.join([t.text for t in tokens])
|
||||
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
|
||||
matches = []
|
||||
for pattern in patterns:
|
||||
for m in pattern.finditer(full_string):
|
||||
matches.append({
|
||||
'start': m.start(),
|
||||
'end': m.end()
|
||||
})
|
||||
if matches:
|
||||
# cleaned = pattern.sub(' ', full_string).strip()
|
||||
# print("Cleaned:", cleaned)
|
||||
cumulated_len = 0
|
||||
silence_token = None
|
||||
cleaned_tokens = []
|
||||
for token in tokens:
|
||||
if matches:
|
||||
start = cumulated_len
|
||||
end = cumulated_len + len(token.text)
|
||||
cumulated_len = end
|
||||
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||
if silence_token: #previous token was already silence
|
||||
silence_token.start = min(silence_token.start, token.start)
|
||||
silence_token.end = max(silence_token.end, token.end)
|
||||
else: #new silence
|
||||
silence_token = ASRToken(
|
||||
start=token.start,
|
||||
end=token.end,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
else:
|
||||
if silence_token: #there was silence but no more
|
||||
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||
cleaned_tokens.append(
|
||||
silence_token
|
||||
)
|
||||
silence_token = None
|
||||
matches.pop(0)
|
||||
cleaned_tokens.append(token)
|
||||
# print(cleaned_tokens)
|
||||
return cleaned_tokens
|
||||
return tokens
|
||||
|
||||
def no_token_to_silence(tokens):
|
||||
new_tokens = []
|
||||
silence_token = None
|
||||
for token in tokens:
|
||||
if token.speaker == -2:
|
||||
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||
new_tokens[-1].end = token.end
|
||||
else:
|
||||
new_tokens.append(token)
|
||||
|
||||
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
|
||||
if new_tokens and new_tokens[-1].speaker == -2:
|
||||
new_tokens[-1].end = token.start
|
||||
else:
|
||||
silence_token = ASRToken(
|
||||
start=last_end,
|
||||
end=token.start,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
new_tokens.append(silence_token)
|
||||
|
||||
if token.speaker != -2:
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
||||
current_time = time() - (beg_loop if beg_loop else 0.0)
|
||||
last_token = tokens[-1]
|
||||
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
||||
if last_token.speaker == -2:
|
||||
last_token.end = current_time
|
||||
else:
|
||||
tokens.append(
|
||||
ASRToken(
|
||||
start=tokens[-1].end,
|
||||
end=current_time,
|
||||
speaker=-2,
|
||||
probability=0.95
|
||||
)
|
||||
)
|
||||
return tokens
|
||||
|
||||
|
||||
def handle_silences(tokens, beg_loop, vac_detected_silence):
|
||||
if not tokens:
|
||||
return []
|
||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||
tokens = no_token_to_silence(tokens)
|
||||
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
|
||||
return tokens
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
|
||||
import logging
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
CHECK_AROUND = 4
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def is_punctuation(token):
|
||||
if token.is_punctuation():
|
||||
return True
|
||||
return False
|
||||
|
||||
def next_punctuation_change(i, tokens):
|
||||
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||
if is_punctuation(tokens[ind]):
|
||||
return ind
|
||||
return None
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
|
||||
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||
token = tokens[ind]
|
||||
if is_punctuation(token):
|
||||
break
|
||||
if token.speaker != speaker:
|
||||
return ind, token.speaker
|
||||
return None, speaker
|
||||
|
||||
def new_line(
|
||||
token,
|
||||
):
|
||||
return Line(
|
||||
speaker = token.corrected_speaker,
|
||||
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
|
||||
start = token.start,
|
||||
end = token.end,
|
||||
detected_language=token.detected_language
|
||||
)
|
||||
|
||||
def append_token_to_last_line(lines, sep, token):
|
||||
if not lines:
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
if token.text:
|
||||
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
|
||||
lines[-1].end = token.end
|
||||
if not lines[-1].detected_language and token.detected_language:
|
||||
lines[-1].detected_language = token.detected_language
|
||||
|
||||
|
||||
def format_output(state, silence, args, sep):
|
||||
diarization = args.diarization
|
||||
disable_punctuation_split = args.disable_punctuation_split
|
||||
tokens = state.tokens
|
||||
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
last_validated_token = state.last_validated_token
|
||||
|
||||
previous_speaker = 1
|
||||
undiarized_text = []
|
||||
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||
for i in range(last_validated_token, len(tokens)):
|
||||
token = tokens[i]
|
||||
speaker = int(token.speaker)
|
||||
token.corrected_speaker = speaker
|
||||
if not diarization:
|
||||
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
token.corrected_speaker = 1
|
||||
token.validated_speaker = True
|
||||
else:
|
||||
if is_punctuation(token):
|
||||
state.last_punctuation_index = i
|
||||
|
||||
if state.last_punctuation_index == i-1:
|
||||
if token.speaker != previous_speaker:
|
||||
token.validated_speaker = True
|
||||
# perfect, diarization perfectly aligned
|
||||
last_punctuation = None
|
||||
else:
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||
token.corrected_speaker = new_speaker
|
||||
token.validated_speaker = True
|
||||
elif speaker != previous_speaker:
|
||||
if not (speaker == -2 or previous_speaker == -2):
|
||||
if next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
token.corrected_speaker = previous_speaker
|
||||
token.validated_speaker = True
|
||||
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
|
||||
if not disable_punctuation_split:
|
||||
token.corrected_speaker = previous_speaker
|
||||
token.validated_speaker = False
|
||||
if token.validated_speaker:
|
||||
state.last_validated_token = i
|
||||
previous_speaker = token.corrected_speaker
|
||||
|
||||
previous_speaker = 1
|
||||
|
||||
lines = []
|
||||
for token in tokens:
|
||||
if int(token.corrected_speaker) != int(previous_speaker):
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
append_token_to_last_line(lines, sep, token)
|
||||
|
||||
previous_speaker = token.corrected_speaker
|
||||
|
||||
if lines:
|
||||
unassigned_translated_segments = []
|
||||
for ts in translation_validated_segments:
|
||||
assigned = False
|
||||
for line in lines:
|
||||
if ts and ts.overlaps_with(line):
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + ' '
|
||||
assigned = True
|
||||
break
|
||||
else:
|
||||
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||
if ts0 and line.overlaps_with(ts0):
|
||||
line.translation += ts0.text + ' '
|
||||
if ts1:
|
||||
unassigned_translated_segments.append(ts1)
|
||||
assigned = True
|
||||
break
|
||||
if not assigned:
|
||||
unassigned_translated_segments.append(ts)
|
||||
|
||||
if unassigned_translated_segments:
|
||||
for line in lines:
|
||||
remaining_segments = []
|
||||
for ts in unassigned_translated_segments:
|
||||
if ts and ts.overlaps_with(line):
|
||||
line.translation += ts.text + ' '
|
||||
else:
|
||||
remaining_segments.append(ts)
|
||||
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
||||
|
||||
if state.buffer_transcription and lines:
|
||||
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||
|
||||
return lines, undiarized_text
|
||||
@@ -1,8 +1,9 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
"""
|
||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||
"""
|
||||
@@ -123,7 +124,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
|
||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||
if model_path is None:
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / 'vad_models'
|
||||
data_dir = current_dir / 'silero_vad_models'
|
||||
|
||||
if onnx:
|
||||
if opset_version == 16:
|
||||
@@ -138,7 +139,7 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found: {model_path}\n"
|
||||
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
|
||||
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
|
||||
)
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
@@ -276,8 +277,10 @@ class FixedVADIterator(VADIterator):
|
||||
elif r is not None:
|
||||
if "end" in r:
|
||||
ret["end"] = r["end"]
|
||||
if "start" in r and "end" in ret:
|
||||
del ret["end"]
|
||||
if "start" in r:
|
||||
ret["start"] = r["start"]
|
||||
if "end" in ret:
|
||||
del ret["end"]
|
||||
return ret if ret != {} else None
|
||||
|
||||
|
||||
|
||||
@@ -1,31 +1,30 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import gc
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import os
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||
from whisperlivekit.warmup import load_file
|
||||
from whisperlivekit.whisper import load_model, tokenizer
|
||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||
import os
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import torch
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
|
||||
else:
|
||||
mlx_model_mapping = {}
|
||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||
@@ -34,6 +33,8 @@ if HAS_FASTER_WHISPER:
|
||||
else:
|
||||
WhisperModel = None
|
||||
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
@@ -48,49 +49,56 @@ class SimulStreamingOnlineProcessor:
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.load_new_backend()
|
||||
self.load_new_alignatt_instance()
|
||||
|
||||
#can be moved
|
||||
if asr.tokenizer:
|
||||
self.model.tokenizer = asr.tokenizer
|
||||
|
||||
def load_new_backend(self):
|
||||
model = self.asr.get_new_model_instance()
|
||||
self.model = PaddedAlignAttWhisper(
|
||||
def load_new_alignatt_instance(self):
|
||||
"""Initialize AlignAtt decoder using the shared model."""
|
||||
self.model = AlignAtt(
|
||||
cfg=self.asr.cfg,
|
||||
loaded_model=model,
|
||||
loaded_model=self.asr.shared_model,
|
||||
mlx_encoder=self.asr.mlx_encoder,
|
||||
fw_encoder=self.asr.fw_encoder,
|
||||
)
|
||||
)
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
def start_silence(self):
|
||||
tokens, processed_upto = self.process_iter(is_last=True)
|
||||
return tokens, processed_upto
|
||||
|
||||
def end_silence(self, silence_duration, offset):
|
||||
"""
|
||||
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||
Handle silence period.
|
||||
|
||||
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
|
||||
Otherwise, insert a small silence and shift the last_attend_frame.
|
||||
"""
|
||||
if silence_duration < 5:
|
||||
gap_silence = torch.zeros(int(16000*silence_duration))
|
||||
self.model.insert_audio(gap_silence)
|
||||
# self.global_time_offset += silence_duration
|
||||
else:
|
||||
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
|
||||
self.end += silence_duration
|
||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||
if not long_silence:
|
||||
gap_len = int(16000 * silence_duration)
|
||||
if gap_len > 0:
|
||||
gap_silence = torch.zeros(gap_len)
|
||||
self.model.insert_audio(gap_silence)
|
||||
if long_silence:
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.global_time_offset = silence_duration + offset
|
||||
|
||||
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||
|
||||
# Convert numpy array to torch tensor
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
self.process_iter(is_last=True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.speaker = change_speaker.speaker
|
||||
self.global_time_offset = change_speaker.start
|
||||
"""Handle speaker change event."""
|
||||
self.process_iter(is_last=True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.speaker = change_speaker.speaker
|
||||
self.model.global_time_offset = change_speaker.start
|
||||
|
||||
def get_buffer(self):
|
||||
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||
@@ -104,15 +112,17 @@ class SimulStreamingOnlineProcessor:
|
||||
"""
|
||||
try:
|
||||
timestamped_words = self.model.infer(is_last=is_last)
|
||||
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
|
||||
|
||||
if not timestamped_words:
|
||||
return [], self.end
|
||||
|
||||
if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"SimulStreaming processing error: {e}")
|
||||
return [], self.end
|
||||
@@ -128,12 +138,8 @@ class SimulStreamingOnlineProcessor:
|
||||
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||
|
||||
def __del__(self):
|
||||
# free the model and add a new model to stack.
|
||||
# del self.model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
# self.asr.new_model_to_stack()
|
||||
self.model.remove_hooks()
|
||||
|
||||
class SimulStreamingASR():
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
@@ -218,10 +224,7 @@ class SimulStreamingASR():
|
||||
self.tokenizer = self.set_translate_task()
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if self.encoder_backend == "mlx-whisper":
|
||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||
@@ -245,8 +248,7 @@ class SimulStreamingASR():
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
|
||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||
self.shared_model = self.load_model()
|
||||
|
||||
|
||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||
@@ -295,39 +297,21 @@ class SimulStreamingASR():
|
||||
download_root=self.model_path,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
)
|
||||
)
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||
if self.fast_encoder:
|
||||
temp_model = PaddedAlignAttWhisper(
|
||||
if self.fast_encoder:
|
||||
temp_model = AlignAtt(
|
||||
cfg=self.cfg,
|
||||
loaded_model=whisper_model,
|
||||
mlx_encoder=self.mlx_encoder,
|
||||
fw_encoder=self.fw_encoder,
|
||||
)
|
||||
temp_model.warmup(warmup_audio)
|
||||
temp_model.remove_hooks()
|
||||
else:
|
||||
# For standard encoder, use the original transcribe warmup
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
||||
return whisper_model
|
||||
|
||||
def get_new_model_instance(self):
|
||||
"""
|
||||
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
|
||||
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
|
||||
"""
|
||||
if len(self.models) == 0:
|
||||
self.models.append(self.load_model())
|
||||
new_model = self.models.pop()
|
||||
return new_model
|
||||
# self.models[0]
|
||||
|
||||
def new_model_to_stack(self):
|
||||
self.models.append(self.load_model())
|
||||
|
||||
|
||||
def set_translate_task(self):
|
||||
"""Set up translation task."""
|
||||
|
||||
@@ -1,17 +1,32 @@
|
||||
from torch import Tensor
|
||||
|
||||
from whisperlivekit.whisper.decoding import PyTorchInference
|
||||
|
||||
# extention of PyTorchInference for beam search
|
||||
class BeamPyTorchInference(PyTorchInference):
|
||||
|
||||
def _kv_modules(self):
|
||||
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
|
||||
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
|
||||
return key_modules + value_modules
|
||||
class BeamPyTorchInference(PyTorchInference):
|
||||
"""Extension of PyTorchInference for beam search with cross-attention support."""
|
||||
|
||||
def _kv_cache_ids(self):
|
||||
"""Get cache_id strings for self-attention key/value modules."""
|
||||
key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks]
|
||||
value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks]
|
||||
return key_ids + value_ids
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
if source_indices != list(range(len(source_indices))):
|
||||
for module_cache_id in self._kv_modules():
|
||||
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
|
||||
from torch import Tensor
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
for cache_id in self._kv_cache_ids():
|
||||
if cache_id in self.kv_cache:
|
||||
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
|
||||
|
||||
def logits(
|
||||
self,
|
||||
tokens: Tensor,
|
||||
audio_features: Tensor,
|
||||
return_cross_attn: bool = False,
|
||||
):
|
||||
"""Get logits, optionally returning cross-attention weights."""
|
||||
return self.model.decoder(
|
||||
tokens, audio_features,
|
||||
kv_cache=self.kv_cache,
|
||||
return_cross_attn=return_cross_attn,
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlignAttConfig():
|
||||
eval_data_path: str = "tmp"
|
||||
|
||||
80
whisperlivekit/simul_whisper/decoder_state.py
Normal file
80
whisperlivekit/simul_whisper/decoder_state.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecoderState:
|
||||
|
||||
kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
|
||||
|
||||
tokenizer: Any = None
|
||||
detected_language: Optional[str] = None
|
||||
reset_tokenizer_to_auto_next_call: bool = False
|
||||
|
||||
tokens: List[torch.Tensor] = field(default_factory=list)
|
||||
initial_tokens: Optional[torch.Tensor] = None
|
||||
initial_token_length: int = 0
|
||||
sot_index: int = 0
|
||||
|
||||
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
|
||||
num_align_heads: int = 0
|
||||
|
||||
segments: List[torch.Tensor] = field(default_factory=list)
|
||||
|
||||
context: Any = None
|
||||
|
||||
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||
|
||||
global_time_offset: float = 0.0
|
||||
cumulative_time_offset: float = 0.0
|
||||
first_timestamp: Optional[float] = None
|
||||
last_attend_frame: int = 0
|
||||
|
||||
speaker: int = -1
|
||||
log_segments: int = 0
|
||||
|
||||
CIFLinear: Optional[torch.nn.Module] = None
|
||||
always_fire: bool = False
|
||||
never_fire: bool = False
|
||||
|
||||
suppress_tokens_fn: Any = None
|
||||
|
||||
token_decoder: Any = None
|
||||
decoder_type: str = "greedy"
|
||||
|
||||
inference: Any = None
|
||||
|
||||
def clean_cache(self):
|
||||
"""Clean the kv_cache after each inference step."""
|
||||
self.kv_cache = {}
|
||||
if self.decoder_type == "beam" and self.inference is not None:
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
if self.token_decoder is not None:
|
||||
self.token_decoder.reset()
|
||||
|
||||
def reset(self, rewind_threshold: int = 200):
|
||||
"""
|
||||
Reset transient state for a new segment.
|
||||
|
||||
Args:
|
||||
rewind_threshold: Value for resetting last_attend_frame
|
||||
"""
|
||||
self.last_attend_frame = -rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.pending_incomplete_tokens = []
|
||||
self.log_segments += 1
|
||||
|
||||
def full_reset(self, rewind_threshold: int = 200):
|
||||
"""
|
||||
Full reset including audio segments and tokens.
|
||||
|
||||
Args:
|
||||
rewind_threshold: Value for resetting last_attend_frame
|
||||
"""
|
||||
self.reset(rewind_threshold)
|
||||
self.segments = []
|
||||
self.tokens = []
|
||||
self.kv_cache = {}
|
||||
self.first_timestamp = None
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
class Tokens:
|
||||
def __init__(self, tokens):
|
||||
self.tokens = tokens
|
||||
|
||||
# def clone(self):
|
||||
# return Tokens(self.tokens.clone())
|
||||
|
||||
def __str__(self):
|
||||
return str(self.tokens.tolist())
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
class BeamTokens(Tokens):
|
||||
def __init__(self, tokens, beam_size):
|
||||
self.tokens = tokens
|
||||
self.beam_size = beam_size
|
||||
|
||||
def clone(self):
|
||||
return BeamTokens(self.tokens.clone())
|
||||
|
||||
def __str__(self):
|
||||
return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def as_text(self, tokenizer):
|
||||
return tokenizer.decode(self.tokens)
|
||||
|
||||
class Logits(Tokens):
|
||||
def __init__(self, logits):
|
||||
super().__init__(logits)
|
||||
|
||||
# def clone(self):
|
||||
# return Logits(self.tokens.clone(), self.beam_size)
|
||||
|
||||
def __str__(self):
|
||||
# return "abc"
|
||||
return f"Logits({self.tokens.shape})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
@@ -5,7 +5,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from mlx_whisper import whisper
|
||||
|
||||
mlx_model_mapping = {
|
||||
|
||||
@@ -1,49 +1,84 @@
|
||||
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
||||
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
from time import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from whisperlivekit.whisper import load_model, DecodingOptions, tokenizer
|
||||
from .config import AlignAttConfig
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||
from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
|
||||
TOKENS_PER_SECOND,
|
||||
log_mel_spectrogram, pad_or_trim)
|
||||
from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
|
||||
SuppressTokens)
|
||||
from whisperlivekit.whisper.timing import median_filter
|
||||
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||
from .beam import BeamPyTorchInference
|
||||
from .eow_detection import fire_at_boundary, load_cif
|
||||
import os
|
||||
from time import time
|
||||
from .token_buffer import TokenBuffer
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from ..timed_objects import PUNCTUATION_MARKS
|
||||
from .generation_progress import *
|
||||
from .beam import BeamPyTorchInference
|
||||
from .config import AlignAttConfig
|
||||
from .decoder_state import DecoderState
|
||||
from .eow_detection import fire_at_boundary, load_cif
|
||||
from .token_buffer import TokenBuffer
|
||||
|
||||
DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HAS_MLX_WHISPER = False
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
if mlx_backend_available():
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.audio import \
|
||||
log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
HAS_MLX_WHISPER = True
|
||||
|
||||
if faster_backend_available():
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
HAS_FASTER_WHISPER = True
|
||||
|
||||
class PaddedAlignAttWhisper:
|
||||
USE_MLCORE = False
|
||||
|
||||
|
||||
def load_coreml_encoder():
|
||||
try:
|
||||
from coremltools.models import MLModel
|
||||
except ImportError:
|
||||
logger.warning("coremltools is not installed")
|
||||
return None
|
||||
COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage")
|
||||
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
|
||||
spec = _coreml_encoder.get_spec()
|
||||
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
|
||||
_coreml_output_name = spec.description.output[0].name if spec.description.output else None
|
||||
return _coreml_encoder, _coreml_input_name, _coreml_output_name
|
||||
|
||||
|
||||
class AlignAtt:
|
||||
"""
|
||||
Alignment-based Attention decoder for SimulStreaming.
|
||||
|
||||
This class is now hookless - the model can be shared across multiple
|
||||
sessions, with each session maintaining its own DecoderState.
|
||||
"""
|
||||
|
||||
# Property accessors for backward compatibility
|
||||
@property
|
||||
def speaker(self):
|
||||
return self.state.speaker
|
||||
|
||||
@speaker.setter
|
||||
def speaker(self, value):
|
||||
self.state.speaker = value
|
||||
|
||||
@property
|
||||
def global_time_offset(self):
|
||||
return self.state.global_time_offset
|
||||
|
||||
@global_time_offset.setter
|
||||
def global_time_offset(self, value):
|
||||
self.state.global_time_offset = value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: AlignAttConfig,
|
||||
@@ -51,130 +86,103 @@ class PaddedAlignAttWhisper:
|
||||
mlx_encoder=None,
|
||||
fw_encoder=None,
|
||||
) -> None:
|
||||
self.log_segments = 0
|
||||
|
||||
# Shared model reference (can be shared across sessions)
|
||||
self.model = loaded_model
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
if fw_encoder:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
self.coreml_encoder_tuple = None
|
||||
if USE_MLCORE:
|
||||
self.coreml_encoder_tuple = load_coreml_encoder()
|
||||
self.use_mlcore = self.coreml_encoder_tuple is not None
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
self.speaker = -1
|
||||
self.decode_options = DecodingOptions(
|
||||
language = cfg.language,
|
||||
without_timestamps = True,
|
||||
language=cfg.language,
|
||||
without_timestamps=True,
|
||||
task=cfg.task
|
||||
)
|
||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
# self.create_tokenizer('en')
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
self.global_time_offset = 0.0
|
||||
self.reset_tokenizer_to_auto_next_call = False
|
||||
|
||||
self.max_text_len = self.model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||
self.cfg = cfg
|
||||
self.l_hooks = []
|
||||
|
||||
# model to detect end-of-word boundary at the end of the segment
|
||||
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
||||
n_audio_state=self.model.dims.n_audio_state,
|
||||
device=self.model.device)
|
||||
|
||||
# install hooks to access encoder-decoder attention
|
||||
self.dec_attns = []
|
||||
def layer_hook(module, net_input, net_output):
|
||||
# net_output[1]: B*num_head*token_len*audio_len
|
||||
t = F.softmax(net_output[1], dim=-1)
|
||||
self.dec_attns.append(t.squeeze(0))
|
||||
for b in self.model.decoder.blocks:
|
||||
hook = b.cross_attn.register_forward_hook(layer_hook)
|
||||
self.l_hooks.append(hook)
|
||||
|
||||
self.kv_cache = {}
|
||||
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
|
||||
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
|
||||
# save as-is, for the first token or cross attention
|
||||
self.kv_cache[module.cache_id] = net_output
|
||||
else:
|
||||
x = self.kv_cache[module.cache_id]
|
||||
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
|
||||
return self.kv_cache[module.cache_id]
|
||||
|
||||
for i,b in enumerate(self.model.decoder.blocks):
|
||||
hooks = [
|
||||
b.attn.key.register_forward_hook(kv_hook),
|
||||
b.attn.value.register_forward_hook(kv_hook),
|
||||
b.cross_attn.key.register_forward_hook(kv_hook),
|
||||
b.cross_attn.value.register_forward_hook(kv_hook),
|
||||
]
|
||||
self.l_hooks.extend(hooks)
|
||||
|
||||
self.align_source = {}
|
||||
self.num_align_heads = 0
|
||||
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
||||
layer_rank = layer_rank.item()
|
||||
heads = self.align_source.get(layer_rank, [])
|
||||
heads.append((self.num_align_heads, head_id.item()))
|
||||
self.align_source[layer_rank] = heads
|
||||
self.num_align_heads += 1
|
||||
|
||||
|
||||
# tokens to be suppressed from decoding, to prevent hallucinations
|
||||
suppress_tokens = [
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
# self.tokenizer.eot
|
||||
self.tokenizer.no_timestamps, # added by DM
|
||||
] + list(self.tokenizer.all_language_tokens) # added by DM
|
||||
if self.tokenizer.no_speech is not None:
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||
logger.debug(f"Suppress tokens: {suppress_tokens}")
|
||||
sup_tokens = SuppressTokens(suppress_tokens)
|
||||
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
||||
# blank tokens are suppresed for new segments near the line 334
|
||||
|
||||
# it's going to be regenerated after lang id
|
||||
self.segments = []
|
||||
self.init_tokens()
|
||||
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.first_timestamp = None
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
else:
|
||||
self.max_context_tokens = self.cfg.max_context_tokens
|
||||
|
||||
# Initialize per-session state
|
||||
self.state = DecoderState()
|
||||
self._init_state(cfg)
|
||||
|
||||
def _init_state(self, cfg: AlignAttConfig):
|
||||
"""Initialize the per-session decoder state."""
|
||||
# Create tokenizer
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
self.state.tokenizer = self.tokenizer
|
||||
self.state.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
|
||||
# Timing state
|
||||
self.state.global_time_offset = 0.0
|
||||
self.state.last_attend_frame = -cfg.rewind_threshold
|
||||
self.state.speaker = -1
|
||||
|
||||
# CIF helpers for end-of-word boundary detection
|
||||
self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif(
|
||||
cfg,
|
||||
n_audio_state=self.model.dims.n_audio_state,
|
||||
device=self.model.device
|
||||
)
|
||||
|
||||
# Build alignment source mapping from model's alignment_heads
|
||||
self.state.align_source = {}
|
||||
self.state.num_align_heads = 0
|
||||
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
||||
layer_rank = layer_rank.item()
|
||||
heads = self.state.align_source.get(layer_rank, [])
|
||||
heads.append((self.state.num_align_heads, head_id.item()))
|
||||
self.state.align_source[layer_rank] = heads
|
||||
self.state.num_align_heads += 1
|
||||
|
||||
# Build suppress tokens function
|
||||
suppress_tokens = [
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
self.tokenizer.no_timestamps,
|
||||
] + list(self.tokenizer.all_language_tokens)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||
logger.debug(f"Suppress tokens: {suppress_tokens}")
|
||||
sup_tokens = SuppressTokens(suppress_tokens)
|
||||
self.state.suppress_tokens_fn = lambda logits: sup_tokens.apply(logits, None)
|
||||
|
||||
# Initialize tokens
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
|
||||
# decoder type: greedy or beam
|
||||
# Set up decoder type
|
||||
self.state.decoder_type = cfg.decoder_type
|
||||
if cfg.decoder_type == "greedy":
|
||||
logger.info("Using greedy decoder")
|
||||
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
||||
self.decoder_type = "greedy"
|
||||
|
||||
self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
||||
elif cfg.decoder_type == "beam":
|
||||
self.decoder_type = "beam"
|
||||
self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
|
||||
# Tokens to carry over to next chunk for incomplete UTF-8 characters
|
||||
self.pending_incomplete_tokens = []
|
||||
|
||||
def remove_hooks(self):
|
||||
for hook in self.l_hooks:
|
||||
hook.remove()
|
||||
logger.info("Using beam decoder")
|
||||
self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length)
|
||||
self.state.inference.kv_cache = self.state.kv_cache
|
||||
self.state.token_decoder = BeamSearchDecoder(
|
||||
inference=self.state.inference,
|
||||
eot=self.tokenizer.eot,
|
||||
beam_size=cfg.beam_size
|
||||
)
|
||||
|
||||
def warmup(self, audio):
|
||||
try:
|
||||
@@ -192,96 +200,100 @@ class PaddedAlignAttWhisper:
|
||||
num_languages=self.model.num_languages,
|
||||
task=self.decode_options.task
|
||||
)
|
||||
self.state.tokenizer = self.tokenizer
|
||||
|
||||
def init_context(self):
|
||||
kw = {'tokenizer': self.tokenizer,
|
||||
'device': self.model.device,
|
||||
'prefix_token_ids': [self.tokenizer.sot_prev]}
|
||||
self.context = TokenBuffer.empty(**kw)
|
||||
self.state.context = TokenBuffer.empty(**kw)
|
||||
if self.cfg.static_init_prompt is not None:
|
||||
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||
self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||
if self.cfg.init_prompt is not None:
|
||||
self.context.text += self.cfg.init_prompt
|
||||
self.state.context.text += self.cfg.init_prompt
|
||||
|
||||
def init_tokens(self):
|
||||
logger.debug(f"init tokens, {len(self.segments)}")
|
||||
logger.debug(f"init tokens, {len(self.state.segments)}")
|
||||
# init tokens (mandatory prompt)
|
||||
self.initial_tokens = torch.tensor(
|
||||
self.state.initial_tokens = torch.tensor(
|
||||
self.tokenizer.sot_sequence_including_notimestamps,
|
||||
dtype=torch.long,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
self.initial_token_length = self.initial_tokens.shape[1]
|
||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
# self.segments = []
|
||||
logger.debug(f"init tokens after, {len(self.segments)}")
|
||||
self.tokens = [self.initial_tokens]
|
||||
self.state.initial_token_length = self.state.initial_tokens.shape[1]
|
||||
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
logger.debug(f"init tokens after, {len(self.state.segments)}")
|
||||
self.state.tokens = [self.state.initial_tokens]
|
||||
|
||||
def trim_context(self):
|
||||
logger.info("Trimming context")
|
||||
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
||||
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
|
||||
logger.info(f"Context text: {self.context.as_text()}")
|
||||
# logger.debug(f"Context tensor: {self.context.as_tensor()}")
|
||||
l = sum(t.shape[1] for t in self.tokens) + c
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
|
||||
logger.info(f"Context text: {self.state.context.as_text()}")
|
||||
l = sum(t.shape[1] for t in self.state.tokens) + c
|
||||
if self.cfg.static_init_prompt is None:
|
||||
after = 0
|
||||
else:
|
||||
after = len(self.cfg.static_init_prompt)
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
||||
t = self.context.trim_words(after=after)
|
||||
t = self.state.context.trim_words(after=after)
|
||||
l -= t
|
||||
c -= t
|
||||
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
if t == 0:
|
||||
break
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
logger.info(f"Context after trim: {self.context.text} (len: {l})")
|
||||
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
|
||||
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
|
||||
if self.cfg.decoder_type == "greedy":
|
||||
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
def logits(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
audio_features: torch.Tensor,
|
||||
return_cross_attn: bool = False
|
||||
):
|
||||
"""Get logits from decoder, optionally returning cross-attention weights."""
|
||||
if self.state.decoder_type == "greedy":
|
||||
return self.model.decoder(
|
||||
tokens, audio_features,
|
||||
kv_cache=self.state.kv_cache,
|
||||
return_cross_attn=return_cross_attn
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Logits shape: {tokens.shape}")
|
||||
logit = self.inference.logits(tokens, audio_features)
|
||||
return logit
|
||||
return self.state.inference.logits(
|
||||
tokens, audio_features,
|
||||
return_cross_attn=return_cross_attn
|
||||
)
|
||||
|
||||
|
||||
def refresh_segment(self, complete=False):
|
||||
|
||||
logger.debug("Refreshing segment:")
|
||||
self.init_tokens()
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.detected_language = None
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.state.cumulative_time_offset = 0.0
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
if not complete and len(self.segments) > 2:
|
||||
self.segments = self.segments[-2:]
|
||||
logger.debug(f"Context: {self.state.context}")
|
||||
if not complete and len(self.state.segments) > 2:
|
||||
self.state.segments = self.state.segments[-2:]
|
||||
else:
|
||||
logger.debug("removing all segments.")
|
||||
self.segments = []
|
||||
self.log_segments += 1
|
||||
|
||||
self.pending_incomplete_tokens = []
|
||||
self.state.segments = []
|
||||
self.state.log_segments += 1
|
||||
self.state.pending_incomplete_tokens = []
|
||||
|
||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||
if self.always_fire: return True
|
||||
if self.never_fire: return False
|
||||
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
||||
|
||||
if self.state.always_fire:
|
||||
return True
|
||||
if self.state.never_fire:
|
||||
return False
|
||||
return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear)
|
||||
|
||||
def _current_tokens(self):
|
||||
|
||||
toks = self.tokens
|
||||
toks = self.state.tokens
|
||||
# very first infer: duplicate start of seq to beam_size
|
||||
if toks[0].shape[0] == 1:
|
||||
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
|
||||
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)
|
||||
|
||||
if not self.context.is_empty():
|
||||
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
|
||||
if not self.state.context.is_empty():
|
||||
context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
|
||||
toks = [context_toks] + toks
|
||||
|
||||
# make it one tensor
|
||||
@@ -301,7 +313,7 @@ class PaddedAlignAttWhisper:
|
||||
### audio buffer
|
||||
|
||||
def segments_len(self):
|
||||
segments_len = sum(s.shape[0] for s in self.segments) / 16000
|
||||
segments_len = sum(s.shape[0] for s in self.state.segments) / 16000
|
||||
return segments_len
|
||||
|
||||
def _apply_minseglen(self):
|
||||
@@ -314,42 +326,36 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
def insert_audio(self, segment=None):
|
||||
if segment is not None:
|
||||
self.segments.append(segment)
|
||||
self.state.segments.append(segment)
|
||||
|
||||
removed_len = 0
|
||||
# len of audio is bigger than buffer_len. Going to remove the first segment
|
||||
segments_len = self.segments_len()
|
||||
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||
removed_len = self.segments[0].shape[0] / 16000
|
||||
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||
removed_len = self.state.segments[0].shape[0] / 16000
|
||||
segments_len -= removed_len
|
||||
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
||||
self.cumulative_time_offset += removed_len # Track cumulative time removed
|
||||
self.segments = self.segments[1:]
|
||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
||||
if len(self.tokens) > 1:
|
||||
self.context.append_token_ids(self.tokens[1][0,:].tolist())
|
||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
|
||||
self.state.cumulative_time_offset += removed_len # Track cumulative time removed
|
||||
self.state.segments = self.state.segments[1:]
|
||||
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
|
||||
if len(self.state.tokens) > 1:
|
||||
self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist())
|
||||
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
|
||||
return removed_len
|
||||
|
||||
def _clean_cache(self):
|
||||
'''clean the cache that stores the attention matrices and kv_cache.
|
||||
It must be called every time after generation with the model.'''
|
||||
# cleaning cache
|
||||
self.dec_attns = []
|
||||
self.kv_cache = {}
|
||||
if self.decoder_type == "beam":
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
self.token_decoder.reset()
|
||||
"""Clean the kv_cache after each inference step."""
|
||||
self.state.clean_cache()
|
||||
|
||||
@torch.no_grad()
|
||||
def lang_id(self, encoder_features):
|
||||
"""Language detection from encoder features.
|
||||
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
|
||||
This code is trimmed and copy-pasted from whisper.decoding.detect_language.
|
||||
"""
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = encoder_features.shape[0]
|
||||
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
||||
# Note: don't use kv_cache for language detection
|
||||
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
@@ -379,29 +385,42 @@ class PaddedAlignAttWhisper:
|
||||
@torch.no_grad()
|
||||
def infer(self, is_last=False):
|
||||
new_segment = True
|
||||
if len(self.segments) == 0:
|
||||
if len(self.state.segments) == 0:
|
||||
logger.debug("No segments, nothing to do")
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
input_segments = torch.cat(self.state.segments, dim=0)
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
if len(self.segments) > 1:
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
if len(self.state.segments) > 1:
|
||||
input_segments = torch.cat(self.state.segments, dim=0)
|
||||
else:
|
||||
input_segments = self.segments[0]
|
||||
input_segments = self.state.segments[0]
|
||||
|
||||
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
|
||||
# logger.debug("Resetting tokenizer to auto for new sentence.")
|
||||
# self.create_tokenizer(None)
|
||||
# self.detected_language = None
|
||||
# self.init_tokens()
|
||||
# self.reset_tokenizer_to_auto_next_call = False
|
||||
|
||||
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
|
||||
beg_encode = time()
|
||||
if self.use_mlcore:
|
||||
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
|
||||
mel_padded = log_mel_spectrogram(
|
||||
input_segments,
|
||||
n_mels=self.model.dims.n_mels,
|
||||
padding=N_SAMPLES,
|
||||
device="cpu",
|
||||
).unsqueeze(0)
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
||||
mel_np = np.ascontiguousarray(mel.numpy())
|
||||
ml_inputs = {coreml_input_name or "mel": mel_np}
|
||||
coreml_outputs = coreml_encoder.predict(ml_inputs)
|
||||
if coreml_output_name and coreml_output_name in coreml_outputs:
|
||||
encoder_feature_np = coreml_outputs[coreml_output_name]
|
||||
else:
|
||||
encoder_feature_np = next(iter(coreml_outputs.values()))
|
||||
encoder_feature = torch.as_tensor(
|
||||
np.array(encoder_feature_np),
|
||||
device=self.device,
|
||||
)
|
||||
if self.mlx_encoder:
|
||||
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
|
||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||
@@ -432,18 +451,18 @@ class PaddedAlignAttWhisper:
|
||||
end_encode = time()
|
||||
# print('Encoder duration:', end_encode-beg_encode)
|
||||
|
||||
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.first_timestamp
|
||||
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.state.first_timestamp
|
||||
if seconds_since_start >= 2.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
self.create_tokenizer(top_lan)
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.state.cumulative_time_offset = 0.0
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
self.detected_language = top_lan
|
||||
self.state.detected_language = top_lan
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
|
||||
self.trim_context()
|
||||
@@ -463,92 +482,80 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
l_absolute_timestamps = []
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
accumulated_cross_attns = []
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
|
||||
if new_segment:
|
||||
tokens_for_logits = current_tokens
|
||||
else:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens_for_logits = current_tokens[:,-1:]
|
||||
tokens_for_logits = current_tokens[:, -1:]
|
||||
|
||||
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
||||
# Get logits and cross-attention weights from decoder
|
||||
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
||||
logits, cross_attns = result
|
||||
|
||||
# Accumulate cross-attention from this forward pass
|
||||
accumulated_cross_attns.append(cross_attns)
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
||||
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||
logger.info("no speech, stop")
|
||||
break
|
||||
|
||||
logits = logits[:, -1, :] # logits for the last token
|
||||
logits = logits[:, -1, :] # logits for the last token
|
||||
|
||||
# supress blank tokens only at the beginning of the segment
|
||||
# suppress blank tokens only at the beginning of the segment
|
||||
if new_segment:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
new_segment = False
|
||||
self.suppress_tokens(logits)
|
||||
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
self.state.suppress_tokens_fn(logits)
|
||||
current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
|
||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
|
||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
||||
for i, attn_mat in enumerate(self.dec_attns):
|
||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
||||
align_heads_in_layer = self.align_source.get(layer_rank, [])
|
||||
if len(align_heads_in_layer) == 0:
|
||||
continue
|
||||
for align_head_rank, head_id in align_heads_in_layer:
|
||||
if self.cfg.beam_size == 1:
|
||||
a = attn_mat[head_id, :, :]
|
||||
a = a.unsqueeze(0)
|
||||
else:
|
||||
a = attn_mat[:, head_id, :, :]
|
||||
attn_of_alignment_heads[align_head_rank].append(a)
|
||||
tmp = []
|
||||
for mat in attn_of_alignment_heads:
|
||||
t = torch.cat(mat, dim=1)
|
||||
tmp.append(t)
|
||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
||||
# Process accumulated cross-attention weights for alignment
|
||||
attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len)
|
||||
|
||||
# for each beam, the most attended frame is:
|
||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1)
|
||||
|
||||
# Calculate absolute timestamps accounting for cumulative offset
|
||||
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
||||
absolute_timestamps = [
|
||||
(frame * 0.02 + self.state.cumulative_time_offset)
|
||||
for frame in most_attended_frames.tolist()
|
||||
]
|
||||
|
||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)")
|
||||
|
||||
most_attended_frame = most_attended_frames[0].item()
|
||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||
|
||||
logger.debug("current tokens" + str(current_tokens.shape))
|
||||
if completed:
|
||||
# # stripping the last token, the eot
|
||||
# stripping the last token, the eot
|
||||
current_tokens = current_tokens[:, :-1]
|
||||
break
|
||||
|
||||
# for some rare cases where the attention fails
|
||||
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
||||
# TODO: check this
|
||||
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
||||
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
|
||||
logger.debug("ommit rewinding from special tokens")
|
||||
self.last_attend_frame = most_attended_frame
|
||||
logger.debug("omit rewinding from special tokens")
|
||||
self.state.last_attend_frame = most_attended_frame
|
||||
else:
|
||||
logger.debug(
|
||||
f"[rewind detected] current attention pos: {most_attended_frame}, "
|
||||
f"last attention pos: {self.last_attend_frame}; omit this segment")
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
|
||||
f"last attention pos: {self.state.last_attend_frame}; omit this segment")
|
||||
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||
current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
|
||||
break
|
||||
else:
|
||||
self.last_attend_frame = most_attended_frame
|
||||
self.state.last_attend_frame = most_attended_frame
|
||||
|
||||
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
||||
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
||||
@@ -568,12 +575,12 @@ class PaddedAlignAttWhisper:
|
||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||
|
||||
# Prepend pending tokens from previous chunk if any
|
||||
if self.pending_incomplete_tokens:
|
||||
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
|
||||
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||
if self.state.pending_incomplete_tokens:
|
||||
logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}")
|
||||
pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
||||
|
||||
if fire_detected or is_last: #or punctuation_stop:
|
||||
if fire_detected or is_last:
|
||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||
else:
|
||||
@@ -584,20 +591,18 @@ class PaddedAlignAttWhisper:
|
||||
else:
|
||||
new_hypothesis = []
|
||||
|
||||
|
||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||
device=self.device,
|
||||
)
|
||||
self.tokens.append(new_tokens)
|
||||
self.state.tokens.append(new_tokens)
|
||||
|
||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||
|
||||
self._clean_cache()
|
||||
|
||||
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
|
||||
self.first_timestamp = l_absolute_timestamps[0]
|
||||
|
||||
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
|
||||
self.state.first_timestamp = l_absolute_timestamps[0]
|
||||
|
||||
timestamped_words = []
|
||||
timestamp_idx = 0
|
||||
@@ -616,21 +621,85 @@ class PaddedAlignAttWhisper:
|
||||
timestamp_idx += len(word_tokens)
|
||||
|
||||
timestamp_entry = ASRToken(
|
||||
start=current_timestamp,
|
||||
end=current_timestamp + 0.1,
|
||||
text= word,
|
||||
probability=0.95,
|
||||
speaker=self.speaker,
|
||||
detected_language=self.detected_language
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
start=round(current_timestamp, 2),
|
||||
end=round(current_timestamp + 0.1, 2),
|
||||
text=word,
|
||||
speaker=self.state.speaker,
|
||||
detected_language=self.state.detected_language
|
||||
).with_offset(
|
||||
self.state.global_time_offset
|
||||
)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
# Hold incomplete tokens for next chunk
|
||||
self.pending_incomplete_tokens = []
|
||||
self.state.pending_incomplete_tokens = []
|
||||
if split_words and replacement_char in split_words[-1]:
|
||||
self.pending_incomplete_tokens = split_tokens[-1]
|
||||
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
|
||||
self.state.pending_incomplete_tokens = split_tokens[-1]
|
||||
logger.warning(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.state.pending_incomplete_tokens}")
|
||||
|
||||
return timestamped_words
|
||||
|
||||
def _process_cross_attention(
|
||||
self,
|
||||
cross_attns: List[torch.Tensor],
|
||||
content_mel_len: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process cross-attention weights from decoder layers for alignment.
|
||||
|
||||
Args:
|
||||
cross_attns: List of cross-attention tensors from each decoder layer.
|
||||
Each tensor has shape (batch, n_head, seq_len, audio_len)
|
||||
content_mel_len: Length of actual audio content in mel frames
|
||||
|
||||
Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len)
|
||||
"""
|
||||
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
|
||||
num_decoder_layers = len(self.model.decoder.blocks)
|
||||
|
||||
if cross_attns and isinstance(cross_attns[0], list):
|
||||
flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list]
|
||||
else:
|
||||
flattened_attns = cross_attns
|
||||
|
||||
for idx, attn_mat in enumerate(flattened_attns):
|
||||
layer_rank = idx % num_decoder_layers
|
||||
# attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1
|
||||
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
|
||||
if len(align_heads_in_layer) == 0:
|
||||
continue
|
||||
|
||||
attn_mat = F.softmax(attn_mat, dim=-1)
|
||||
|
||||
for align_head_rank, head_id in align_heads_in_layer:
|
||||
if self.cfg.beam_size == 1:
|
||||
# (n_head, seq_len, audio_len) when squeezed
|
||||
if attn_mat.dim() == 4:
|
||||
a = attn_mat[0, head_id, :, :] # (seq_len, audio_len)
|
||||
else:
|
||||
a = attn_mat[head_id, :, :]
|
||||
a = a.unsqueeze(0) # (1, seq_len, audio_len)
|
||||
else:
|
||||
# attn_mat: (batch, n_head, seq_len, audio_len)
|
||||
a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len)
|
||||
attn_of_alignment_heads[align_head_rank].append(a)
|
||||
|
||||
tmp = []
|
||||
for mat in attn_of_alignment_heads:
|
||||
if mat:
|
||||
t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len)
|
||||
tmp.append(t)
|
||||
|
||||
if not tmp:
|
||||
return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device)
|
||||
|
||||
# stck al heads: (batch, num_align_heads, seq_len, audio_len)
|
||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||
|
||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
|
||||
|
||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)
|
||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
|
||||
return attn_of_alignment_heads
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import torch
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TokenBuffer:
|
||||
|
||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Any, List
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
@@ -8,22 +8,19 @@ def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedText:
|
||||
class Timed:
|
||||
start: Optional[float] = 0
|
||||
end: Optional[float] = 0
|
||||
|
||||
@dataclass
|
||||
class TimedText(Timed):
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
probability: Optional[float] = None
|
||||
is_dummy: Optional[bool] = False
|
||||
detected_language: Optional[str] = None
|
||||
|
||||
def is_punctuation(self):
|
||||
return self.text.strip() in PUNCTUATION_MARKS
|
||||
|
||||
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||
return not (self.end <= other.start or other.end <= self.start)
|
||||
def has_punctuation(self) -> bool:
|
||||
return any(char in PUNCTUATION_MARKS for char in self.text.strip())
|
||||
|
||||
def is_within(self, other: 'TimedText') -> bool:
|
||||
return other.contains_timespan(self)
|
||||
@@ -31,27 +28,25 @@ class TimedText:
|
||||
def duration(self) -> float:
|
||||
return self.end - self.start
|
||||
|
||||
def contains_time(self, time: float) -> bool:
|
||||
return self.start <= time <= self.end
|
||||
|
||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||
return self.start <= other.start and self.end >= other.end
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.text)
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.text)
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
|
||||
corrected_speaker: Optional[int] = -1
|
||||
validated_speaker: bool = False
|
||||
validated_text: bool = False
|
||||
validated_language: bool = False
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language)
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
@@ -70,68 +65,94 @@ class Transcript(TimedText):
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> "Transcript":
|
||||
"""Collapse multiple ASR tokens into a single transcript span."""
|
||||
sep = sep if sep is not None else ' '
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return cls(start, end, text, probability=probability)
|
||||
return cls(start, end, text)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeakerSegment(TimedText):
|
||||
class SpeakerSegment(Timed):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
speaker: Optional[int] = -1
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Translation(TimedText):
|
||||
pass
|
||||
|
||||
def approximate_cut_at(self, cut_time):
|
||||
"""
|
||||
Each word in text is considered to be of duration (end-start)/len(words in text)
|
||||
"""
|
||||
if not self.text or not self.contains_time(cut_time):
|
||||
return self, None
|
||||
|
||||
words = self.text.split()
|
||||
num_words = len(words)
|
||||
if num_words == 0:
|
||||
return self, None
|
||||
|
||||
duration_per_word = self.duration() / num_words
|
||||
|
||||
cut_word_index = int((cut_time - self.start) / duration_per_word)
|
||||
|
||||
if cut_word_index >= num_words:
|
||||
cut_word_index = num_words -1
|
||||
|
||||
text0 = " ".join(words[:cut_word_index])
|
||||
text1 = " ".join(words[cut_word_index:])
|
||||
|
||||
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||
|
||||
return segment0, segment1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: float
|
||||
|
||||
start: Optional[float] = None
|
||||
end: Optional[float] = None
|
||||
duration: Optional[float] = None
|
||||
is_starting: bool = False
|
||||
has_ended: bool = False
|
||||
|
||||
def compute_duration(self) -> Optional[float]:
|
||||
if self.start is None or self.end is None:
|
||||
return None
|
||||
self.duration = self.end - self.start
|
||||
return self.duration
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Segment(TimedText):
|
||||
"""Generic contiguous span built from tokens or silence markers."""
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
text: Optional[str]
|
||||
speaker: Optional[str]
|
||||
@classmethod
|
||||
def from_tokens(
|
||||
cls,
|
||||
tokens: List[Union[ASRToken, Silence]],
|
||||
is_silence: bool = False
|
||||
) -> Optional["Segment"]:
|
||||
"""Return a normalized segment representing the provided tokens."""
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
start_token = tokens[0]
|
||||
end_token = tokens[-1]
|
||||
if is_silence:
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=None,
|
||||
speaker=-2
|
||||
)
|
||||
else:
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=''.join(token.text for token in tokens),
|
||||
speaker=-1,
|
||||
detected_language=start_token.detected_language
|
||||
)
|
||||
def is_silence(self) -> bool:
|
||||
"""True when this segment represents a silence gap."""
|
||||
return self.speaker == -2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the line for frontend consumption."""
|
||||
_dict: Dict[str, Any] = {
|
||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||
'text': self.text,
|
||||
'start': format_time(self.start),
|
||||
@@ -143,6 +164,33 @@ class Line(TimedText):
|
||||
_dict['detected_language'] = self.detected_language
|
||||
return _dict
|
||||
|
||||
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
|
||||
"""Populate line attributes from a contiguous token list."""
|
||||
self.text = ''.join([token.text for token in tokens])
|
||||
self.start = tokens[0].start
|
||||
self.end = tokens[-1].end
|
||||
self.speaker = 1
|
||||
self.detected_language = tokens[0].detected_language
|
||||
return self
|
||||
|
||||
def build_from_segment(self, segment: Segment) -> "Line":
|
||||
"""Populate the line fields from a pre-built segment."""
|
||||
self.text = segment.text
|
||||
self.start = segment.start
|
||||
self.end = segment.end
|
||||
self.speaker = segment.speaker
|
||||
self.detected_language = segment.detected_language
|
||||
return self
|
||||
|
||||
def is_silent(self) -> bool:
|
||||
return self.speaker == -2
|
||||
|
||||
class SilentLine(Line):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.speaker = -2
|
||||
self.text = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontData():
|
||||
@@ -155,8 +203,9 @@ class FrontData():
|
||||
remaining_time_transcription: float = 0.
|
||||
remaining_time_diarization: float = 0.
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the front-end data payload."""
|
||||
_dict: Dict[str, Any] = {
|
||||
'status': self.status,
|
||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||
'buffer_transcription': self.buffer_transcription,
|
||||
@@ -176,14 +225,22 @@ class ChangeSpeaker:
|
||||
|
||||
@dataclass
|
||||
class State():
|
||||
tokens: list = field(default_factory=list)
|
||||
last_validated_token: int = 0
|
||||
last_punctuation_index: Optional[int] = None
|
||||
translation_validated_segments: list = field(default_factory=list)
|
||||
buffer_translation: str = field(default_factory=Transcript)
|
||||
buffer_transcription: str = field(default_factory=Transcript)
|
||||
"""Unified state class for audio processing.
|
||||
|
||||
Contains both persistent state (tokens, buffers) and temporary update buffers
|
||||
(new_* fields) that are consumed by TokensAlignment.
|
||||
"""
|
||||
# Persistent state
|
||||
tokens: List[ASRToken] = field(default_factory=list)
|
||||
buffer_transcription: Transcript = field(default_factory=Transcript)
|
||||
end_buffer: float = 0.0
|
||||
end_attributed_speaker: float = 0.0
|
||||
remaining_time_transcription: float = 0.0
|
||||
remaining_time_diarization: float = 0.0
|
||||
beg_loop: Optional[int] = None
|
||||
|
||||
# Temporary update buffers (consumed by TokensAlignment.update())
|
||||
new_tokens: List[Union[ASRToken, Silence]] = field(default_factory=list)
|
||||
new_translation: List[Any] = field(default_factory=list)
|
||||
new_diarization: List[Any] = field(default_factory=list)
|
||||
new_tokens_buffer: List[Any] = field(default_factory=list) # only when local agreement
|
||||
new_translation_buffer= TimedText()
|
||||
179
whisperlivekit/tokens_alignment.py
Normal file
179
whisperlivekit/tokens_alignment.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from time import time
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence,
|
||||
SilentLine, SpeakerSegment,
|
||||
TimedText)
|
||||
|
||||
|
||||
class TokensAlignment:
|
||||
|
||||
def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
|
||||
self.state = state
|
||||
self.diarization = args.diarization
|
||||
self._tokens_index: int = 0
|
||||
self._diarization_index: int = 0
|
||||
self._translation_index: int = 0
|
||||
|
||||
self.all_tokens: List[ASRToken] = []
|
||||
self.all_diarization_segments: List[SpeakerSegment] = []
|
||||
self.all_translation_segments: List[Any] = []
|
||||
|
||||
self.new_tokens: List[ASRToken] = []
|
||||
self.new_diarization: List[SpeakerSegment] = []
|
||||
self.new_translation: List[Any] = []
|
||||
self.new_translation_buffer: Union[TimedText, str] = TimedText()
|
||||
self.new_tokens_buffer: List[Any] = []
|
||||
self.sep: str = sep if sep is not None else ' '
|
||||
self.beg_loop: Optional[float] = None
|
||||
|
||||
def update(self) -> None:
|
||||
"""Drain state buffers into the running alignment context."""
|
||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
|
||||
self.new_translation, self.state.new_translation = self.state.new_translation, []
|
||||
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
|
||||
|
||||
self.all_tokens.extend(self.new_tokens)
|
||||
self.all_diarization_segments.extend(self.new_diarization)
|
||||
self.all_translation_segments.extend(self.new_translation)
|
||||
self.new_translation_buffer = self.state.new_translation_buffer
|
||||
|
||||
def add_translation(self, line: Line) -> None:
|
||||
"""Append translated text segments that overlap with a line."""
|
||||
for ts in self.all_translation_segments:
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + (self.sep if ts.text else '')
|
||||
elif line.translation:
|
||||
break
|
||||
|
||||
|
||||
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
|
||||
"""Group tokens into segments split by punctuation and explicit silence."""
|
||||
segments = []
|
||||
segment_start_idx = 0
|
||||
for i, token in enumerate(self.all_tokens):
|
||||
if token.is_silence():
|
||||
previous_segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i],
|
||||
)
|
||||
if previous_segment:
|
||||
segments.append(previous_segment)
|
||||
segment = Segment.from_tokens(
|
||||
tokens=[token],
|
||||
is_silence=True
|
||||
)
|
||||
segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
else:
|
||||
if token.has_punctuation():
|
||||
segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i+1],
|
||||
)
|
||||
segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
|
||||
final_segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx:],
|
||||
)
|
||||
if final_segment:
|
||||
segments.append(final_segment)
|
||||
return segments
|
||||
|
||||
|
||||
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
|
||||
"""Merge consecutive diarization slices that share the same speaker."""
|
||||
if not self.all_diarization_segments:
|
||||
return []
|
||||
merged = [self.all_diarization_segments[0]]
|
||||
for segment in self.all_diarization_segments[1:]:
|
||||
if segment.speaker == merged[-1].speaker:
|
||||
merged[-1].end = segment.end
|
||||
else:
|
||||
merged.append(segment)
|
||||
return merged
|
||||
|
||||
|
||||
@staticmethod
|
||||
def intersection_duration(seg1: TimedText, seg2: TimedText) -> float:
|
||||
"""Return the overlap duration between two timed segments."""
|
||||
start = max(seg1.start, seg2.start)
|
||||
end = min(seg1.end, seg2.end)
|
||||
|
||||
return max(0, end - start)
|
||||
|
||||
def get_lines_diarization(self) -> Tuple[List[Line], str]:
|
||||
"""Build lines when diarization is enabled and track overflow buffer."""
|
||||
diarization_buffer = ''
|
||||
punctuation_segments = self.compute_punctuations_segments()
|
||||
diarization_segments = self.concatenate_diar_segments()
|
||||
for punctuation_segment in punctuation_segments:
|
||||
if not punctuation_segment.is_silence():
|
||||
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
|
||||
diarization_buffer += punctuation_segment.text
|
||||
else:
|
||||
max_overlap = 0.0
|
||||
max_overlap_speaker = 1
|
||||
for diarization_segment in diarization_segments:
|
||||
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
|
||||
if intersec > max_overlap:
|
||||
max_overlap = intersec
|
||||
max_overlap_speaker = diarization_segment.speaker + 1
|
||||
punctuation_segment.speaker = max_overlap_speaker
|
||||
|
||||
lines = []
|
||||
if punctuation_segments:
|
||||
lines = [Line().build_from_segment(punctuation_segments[0])]
|
||||
for segment in punctuation_segments[1:]:
|
||||
if segment.speaker == lines[-1].speaker:
|
||||
if lines[-1].text:
|
||||
lines[-1].text += segment.text
|
||||
lines[-1].end = segment.end
|
||||
else:
|
||||
lines.append(Line().build_from_segment(segment))
|
||||
|
||||
return lines, diarization_buffer
|
||||
|
||||
|
||||
def get_lines(
|
||||
self,
|
||||
diarization: bool = False,
|
||||
translation: bool = False,
|
||||
current_silence: Optional[Silence] = None
|
||||
) -> Tuple[List[Line], str, Union[str, TimedText]]:
|
||||
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
|
||||
if diarization:
|
||||
lines, diarization_buffer = self.get_lines_diarization()
|
||||
else:
|
||||
diarization_buffer = ''
|
||||
lines = []
|
||||
current_line_tokens = []
|
||||
for token in self.all_tokens:
|
||||
if token.is_silence():
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
current_line_tokens = []
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = token.start,
|
||||
end = end_silence
|
||||
))
|
||||
else:
|
||||
current_line_tokens.append(token)
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = current_silence.start,
|
||||
end = end_silence
|
||||
))
|
||||
if translation:
|
||||
[self.add_translation(line) for line in lines if not type(line) == Silence]
|
||||
return lines, diarization_buffer, self.new_translation_buffer.text
|
||||
@@ -1,60 +0,0 @@
|
||||
from typing import Sequence, Callable, Any, Optional, Dict
|
||||
|
||||
def _detect_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||
max_tail: int = 300, # search window from the end for speed
|
||||
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||
) -> Optional[Dict]:
|
||||
vals = [key(x) for x in seq][-max_tail:]
|
||||
n = len(vals)
|
||||
best = None
|
||||
|
||||
# try every possible block length
|
||||
for b in range(min_block, n // 2 + 1):
|
||||
block = vals[-b:]
|
||||
# count how many times this block repeats contiguously at the very end
|
||||
count, i = 0, n
|
||||
while i - b >= 0 and vals[i - b:i] == block:
|
||||
count += 1
|
||||
i -= b
|
||||
|
||||
if count >= 2:
|
||||
cand = {
|
||||
"block_size": b,
|
||||
"count": count,
|
||||
"start_index": len(seq) - count * b, # in original seq
|
||||
"end_index": len(seq),
|
||||
}
|
||||
if (best is None or
|
||||
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||
(prefer == "smallest" and b < best["block_size"])):
|
||||
best = cand
|
||||
return best
|
||||
|
||||
def trim_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x,
|
||||
min_block: int = 1,
|
||||
max_tail: int = 300,
|
||||
prefer: str = "longest",
|
||||
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||
):
|
||||
"""
|
||||
Returns a new sequence with repeated tail trimmed.
|
||||
keep=1 -> keep a single copy of the repeated block.
|
||||
keep=0 -> remove all copies of the repeated block.
|
||||
"""
|
||||
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||
if not rep:
|
||||
return seq, False # nothing to trim
|
||||
|
||||
b, c = rep["block_size"], rep["count"]
|
||||
if keep < 0:
|
||||
keep = 0
|
||||
if keep >= c:
|
||||
return seq, False # nothing to trim (already <= keep copies)
|
||||
# new length = total - (copies_to_remove * block_size)
|
||||
new_len = len(seq) - (c - keep) * b
|
||||
return seq[:new_len], True
|
||||
@@ -7,6 +7,7 @@ def load_file(warmup_file=None, timeout=5):
|
||||
import os
|
||||
import tempfile
|
||||
import urllib.request
|
||||
|
||||
import librosa
|
||||
|
||||
if warmup_file == "":
|
||||
|
||||
@@ -391,12 +391,11 @@ function renderLinesWithBuffer(
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing && item.speaker !== -2) {
|
||||
if (remaining_time_transcription > 0) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||
remaining_time_transcription
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||
|
||||
if (buffer_diarization && remaining_time_diarization) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span>s</span></span>`;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import importlib.resources as resources
|
||||
import base64
|
||||
import importlib.resources as resources
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,11 +96,13 @@ def get_inline_ui_html():
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
import pathlib
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
import uvicorn
|
||||
from starlette.staticfiles import StaticFiles
|
||||
import pathlib
|
||||
|
||||
import whisperlivekit.web as webpkg
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -4,17 +4,20 @@ import json
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union, Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import ModelDimensions, Whisper
|
||||
from .transcribe import transcribe
|
||||
from .version import __version__
|
||||
from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
|
||||
pad_or_trim)
|
||||
from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
|
||||
decode, detect_language)
|
||||
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||
from whisperlivekit.whisper.transcribe import transcribe
|
||||
from whisperlivekit.whisper.version import __version__
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
@@ -233,13 +236,97 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
|
||||
return converted if converted else state_dict
|
||||
|
||||
|
||||
def _load_lora_state(lora_path: str):
|
||||
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
||||
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
||||
if os.path.isfile(safe_path):
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Loading LoRA adapters stored as .safetensors requires the `safetensors` package."
|
||||
) from exc
|
||||
return load_file(safe_path)
|
||||
if os.path.isfile(bin_path):
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
raise FileNotFoundError(
|
||||
f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin."
|
||||
)
|
||||
|
||||
|
||||
def _collapse_hf_module_name(module: str):
|
||||
if module.startswith("base_model."):
|
||||
module = module[len("base_model.") :]
|
||||
if module.startswith("model.model."):
|
||||
module = module[len("model.") :]
|
||||
if not module.startswith("model."):
|
||||
module = f"model.{module}"
|
||||
return module
|
||||
|
||||
|
||||
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
|
||||
if not lora_path:
|
||||
return
|
||||
|
||||
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||
if not os.path.isfile(config_path):
|
||||
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
config = json.load(handle)
|
||||
if config.get("peft_type") != "LORA":
|
||||
raise ValueError("Only LoRA adapters are supported.")
|
||||
|
||||
r = config.get("r")
|
||||
alpha = config.get("lora_alpha") or config.get("alpha")
|
||||
if not r or not alpha:
|
||||
raise ValueError("LoRA config must include `r` and `lora_alpha`.")
|
||||
scaling = alpha / r
|
||||
|
||||
adapter_state = _load_lora_state(lora_path)
|
||||
lora_layers: Dict[str, Dict[str, Tensor]] = {}
|
||||
for key, tensor in adapter_state.items():
|
||||
if key.endswith("lora_A.weight"):
|
||||
module = key[: -len(".lora_A.weight")]
|
||||
lora_layers.setdefault(module, {})["A"] = tensor
|
||||
elif key.endswith("lora_B.weight"):
|
||||
module = key[: -len(".lora_B.weight")]
|
||||
lora_layers.setdefault(module, {})["B"] = tensor
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError(f"No LoRA tensors found in {lora_path}")
|
||||
|
||||
for module, parts in lora_layers.items():
|
||||
if "A" not in parts or "B" not in parts:
|
||||
raise ValueError(f"Incomplete LoRA tensors for module '{module}'")
|
||||
|
||||
hf_module = _collapse_hf_module_name(module)
|
||||
hf_weight_key = f"{hf_module}.weight"
|
||||
|
||||
delta = parts["B"] @ parts["A"]
|
||||
delta = delta * scaling
|
||||
|
||||
converted = _convert_hf_state_dict({hf_weight_key: delta})
|
||||
if not converted:
|
||||
raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.")
|
||||
target_name, delta_tensor = next(iter(converted.items()))
|
||||
if target_name not in state_dict:
|
||||
raise KeyError(
|
||||
f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter."
|
||||
)
|
||||
|
||||
state_dict[target_name] = state_dict[target_name] + delta_tensor.to(
|
||||
dtype=state_dict[target_name].dtype, device=state_dict[target_name].device
|
||||
)
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
decoder_only=False,
|
||||
custom_alignment_heads=None
|
||||
decoder_only: bool = False,
|
||||
custom_alignment_heads: Optional[str] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
@@ -255,6 +342,8 @@ def load_model(
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
lora_path: str
|
||||
optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model)
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -302,6 +391,7 @@ def load_model(
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
state_dict = _convert_hf_state_dict(state_dict)
|
||||
_apply_lora_adapter(state_dict, lora_path)
|
||||
|
||||
if dims_cfg is not None:
|
||||
dims = ModelDimensions(**dims_cfg)
|
||||
@@ -329,3 +419,47 @@ def load_model(
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
|
||||
|
||||
def convert_encoder_to_coreml(
|
||||
model_name = "base",
|
||||
output_path= "whisper_encoder.mlpackage",
|
||||
dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
|
||||
precision = "float16",
|
||||
):
|
||||
|
||||
import coremltools as ct
|
||||
model = load_model(model_name, device="cpu", decoder_only=False)
|
||||
encoder = model.encoder.eval().cpu()
|
||||
|
||||
dummy_input = torch.randn(
|
||||
1,
|
||||
model.dims.n_mels,
|
||||
dummy_frames,
|
||||
dtype=next(encoder.parameters()).dtype,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
traced_encoder = torch.jit.trace(encoder, dummy_input)
|
||||
|
||||
precision_map = {
|
||||
"float16": ct.precision.FLOAT16,
|
||||
"fp16": ct.precision.FLOAT16,
|
||||
"float32": ct.precision.FLOAT32,
|
||||
"fp32": ct.precision.FLOAT32,
|
||||
}
|
||||
coreml_precision = precision_map[precision.lower()]
|
||||
|
||||
mlmodel = ct.convert(
|
||||
traced_encoder,
|
||||
inputs=[ct.TensorType(name="mel", shape=dummy_input.shape)],
|
||||
convert_to= "mlprogram",
|
||||
compute_precision=coreml_precision,
|
||||
)
|
||||
|
||||
output_path = Path(output_path)
|
||||
mlmodel.save(str(output_path))
|
||||
return output_path
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")
|
||||
@@ -1,5 +1,6 @@
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
|
||||
Tuple, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -146,16 +147,13 @@ class PyTorchInference(Inference):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
key_modules = [block.attn.key for block in self.model.decoder.blocks]
|
||||
value_modules = [block.attn.value for block in self.model.decoder.blocks]
|
||||
self.kv_modules = key_modules + value_modules
|
||||
self.kv_cache_ids = []
|
||||
for block in self.model.decoder.blocks:
|
||||
self.kv_cache_ids.append(block.attn.key_cache_id)
|
||||
self.kv_cache_ids.append(block.attn.value_cache_id)
|
||||
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
if not self.kv_cache:
|
||||
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
@@ -163,17 +161,14 @@ class PyTorchInference(Inference):
|
||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
|
||||
def cleanup_caching(self):
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
if source_indices != list(range(len(source_indices))):
|
||||
for module in self.kv_modules:
|
||||
# update the key/value cache to contain the selected sequences
|
||||
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
|
||||
for cache_id in self.kv_cache_ids:
|
||||
if cache_id in self.kv_cache:
|
||||
# update the key/value cache to contain the selected sequences
|
||||
self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
|
||||
@@ -79,18 +79,23 @@ def disable_sdpa():
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
|
||||
use_sdpa = False # Disable SDPA to ensure qk is always computed when needed
|
||||
|
||||
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
|
||||
def __init__(self, n_state: int, n_head: int, cache_id: str = "", n_text_ctx: int = 448):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.n_text_ctx = n_text_ctx
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
self.cache_id = cache_id
|
||||
self.key.cache_id = f"{cache_id}_key"
|
||||
self.value.cache_id = f"{cache_id}_value"
|
||||
# Cache IDs for key and value (used with dict-based kv_cache)
|
||||
self.key_cache_id = f"{cache_id}_key"
|
||||
self.value_cache_id = f"{cache_id}_value"
|
||||
# Keep these for backward compatibility with hook-based caching
|
||||
self.key.cache_id = self.key_cache_id
|
||||
self.value.cache_id = self.value_cache_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -101,19 +106,45 @@ class MultiHeadAttention(nn.Module):
|
||||
):
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
if xa is None:
|
||||
# Self-attention
|
||||
k = self.key(x)
|
||||
v = self.value(x)
|
||||
if kv_cache is not None:
|
||||
k, v = self._update_self_attn_cache(k, v, kv_cache)
|
||||
else:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache[self.key]
|
||||
v = kv_cache[self.value]
|
||||
# Cross-attention: compute once and cache, or reuse from cache
|
||||
if kv_cache is not None and self.key_cache_id in kv_cache:
|
||||
k = kv_cache[self.key_cache_id]
|
||||
v = kv_cache[self.value_cache_id]
|
||||
else:
|
||||
k = self.key(xa)
|
||||
v = self.value(xa)
|
||||
if kv_cache is not None:
|
||||
kv_cache[self.key_cache_id] = k
|
||||
kv_cache[self.value_cache_id] = v
|
||||
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
def _update_self_attn_cache(
|
||||
self, k: Tensor, v: Tensor, kv_cache: dict
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Update self-attention kv cache by concatenating new k,v with cached values."""
|
||||
if self.key_cache_id not in kv_cache or k.shape[1] > self.n_text_ctx:
|
||||
# First token or context overflow: save as-is
|
||||
kv_cache[self.key_cache_id] = k.detach()
|
||||
kv_cache[self.value_cache_id] = v.detach()
|
||||
else:
|
||||
# Concatenate with existing cache
|
||||
cached_k = kv_cache[self.key_cache_id]
|
||||
cached_v = kv_cache[self.value_cache_id]
|
||||
k = torch.cat([cached_k, k], dim=1).detach()
|
||||
v = torch.cat([cached_v, v], dim=1).detach()
|
||||
kv_cache[self.key_cache_id] = k
|
||||
kv_cache[self.value_cache_id] = v
|
||||
return k, v
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
@@ -143,14 +174,21 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
|
||||
def __init__(
|
||||
self, n_state: int, n_head: int, cross_attention: bool = False,
|
||||
cache_id: str = "", n_text_ctx: int = 448
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
||||
self.attn = MultiHeadAttention(
|
||||
n_state, n_head, cache_id=f"{cache_id}_self_attn", n_text_ctx=n_text_ctx
|
||||
)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
||||
MultiHeadAttention(
|
||||
n_state, n_head, cache_id=f"{cache_id}_cross_attn", n_text_ctx=n_text_ctx
|
||||
) if cross_attention else None
|
||||
)
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
@@ -166,12 +204,21 @@ class ResidualAttentionBlock(nn.Module):
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""
|
||||
Returns:
|
||||
x: The output tensor
|
||||
cross_attn_qk: Cross-attention weights (if cross_attn exists), else None
|
||||
"""
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
cross_attn_qk = None
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
cross_out, cross_attn_qk = self.cross_attn(
|
||||
self.cross_attn_ln(x), xa, kv_cache=kv_cache
|
||||
)
|
||||
x = x + cross_out
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
return x, cross_attn_qk
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
@@ -201,7 +248,7 @@ class AudioEncoder(nn.Module):
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x, _ = block(x) # Encoder blocks don't have cross-attention
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
@@ -212,13 +259,17 @@ class TextDecoder(nn.Module):
|
||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
self.n_ctx = n_ctx
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
|
||||
ResidualAttentionBlock(
|
||||
n_state, n_head, cross_attention=True,
|
||||
cache_id=f"dec_layer{i}", n_text_ctx=n_ctx
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
@@ -227,28 +278,57 @@ class TextDecoder(nn.Module):
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Tensor,
|
||||
kv_cache: Optional[dict] = None,
|
||||
return_cross_attn: bool = False,
|
||||
):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||
the encoded audio features to be attended on
|
||||
kv_cache : Optional[dict]
|
||||
Dictionary to store/retrieve key-value cache for efficient decoding
|
||||
return_cross_attn : bool
|
||||
If True, return cross-attention weights from all decoder layers
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits : Tensor
|
||||
The output logits
|
||||
cross_attns : Optional[List[Tensor]]
|
||||
List of cross-attention weights per layer (only if return_cross_attn=True)
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
# Calculate offset from self-attention cache (not cross-attention which has audio length)
|
||||
offset = 0
|
||||
if kv_cache:
|
||||
# Use the first decoder block's self-attention key cache to get token position
|
||||
first_self_attn_key = self.blocks[0].attn.key_cache_id
|
||||
if first_self_attn_key in kv_cache:
|
||||
offset = kv_cache[first_self_attn_key].shape[1]
|
||||
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
cross_attns = [] if return_cross_attn else None
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
x, cross_attn_qk = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
if return_cross_attn and cross_attn_qk is not None:
|
||||
cross_attns.append(cross_attn_qk)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
if return_cross_attn:
|
||||
return logits, cross_attns
|
||||
return logits
|
||||
|
||||
|
||||
@@ -292,8 +372,18 @@ class Whisper(nn.Module):
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder(tokens, audio_features)
|
||||
def logits(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
audio_features: torch.Tensor,
|
||||
kv_cache: Optional[dict] = None,
|
||||
return_cross_attn: bool = False,
|
||||
):
|
||||
return self.decoder(
|
||||
tokens, audio_features,
|
||||
kv_cache=kv_cache,
|
||||
return_cross_attn=return_cross_attn
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||
@@ -312,39 +402,6 @@ class Whisper(nn.Module):
|
||||
def num_languages(self):
|
||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||
intermediate tensors to be reused during later calculations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cache : Dict[nn.Module, torch.Tensor]
|
||||
A dictionary object mapping the key/value projection modules to its cache
|
||||
hooks : List[RemovableHandle]
|
||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||
"""
|
||||
cache = {**cache} if cache is not None else {}
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||
# save as-is, for the first token or cross attention
|
||||
cache[module] = output
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
def install_hooks(layer: nn.Module):
|
||||
if isinstance(layer, MultiHeadAttention):
|
||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||
|
||||
self.decoder.apply(install_hooks)
|
||||
return cache, hooks
|
||||
|
||||
detect_language = detect_language_function
|
||||
transcribe = transcribe_function
|
||||
decode = decode_function
|
||||
|
||||
@@ -8,28 +8,13 @@ import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES,
|
||||
SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_end,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
from .utils import (exact_div, format_timestamp, get_end, get_writer,
|
||||
make_safe, optional_float, optional_int, str2bool)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
Reference in New Issue
Block a user