16 Commits

Author SHA1 Message Date
Quentin Fuxa
3bd2122eb4 0.2.8 : only the decoder of whisper is loaded in memory when a different encoder is used 2025-09-02 21:12:25 +02:00
Quentin Fuxa
50b0527858 update architecture 2025-09-01 21:24:12 +02:00
Quentin Fuxa
b044fcdec2 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-09-01 14:55:19 +02:00
Quentin Fuxa
b0508fcf2c mlx/fasterWhisper encoders are loaded once and shared in simulstreaming 2025-09-01 14:55:11 +02:00
Quentin Fuxa
ce89b0aebc Merge pull request #177 from komiyamma/translate-readme-to-japanese
Translate README.md to Japanese
2025-09-01 13:54:50 +02:00
Quentin Fuxa
d5008ed828 mlx/fasterWhisper encoders are loaded once and shared in simulstreaming 2025-09-01 12:33:19 +02:00
Quentin Fuxa
d467716e26 add microphone picker 2025-08-31 10:12:52 +02:00
Quentin Fuxa
199e21b3ef faster-whisper as an optional encoder alternative for simulstreaming 2025-08-30 23:50:16 +02:00
Quentin Fuxa
1d926f2e67 mlx-whisper used as simulstreaming encoder: improve speed for macos systems 2025-08-30 22:19:11 +02:00
Quentin Fuxa
4a71a391b8 get_web_interface_html to get_inline_ui_html for embedded web interface HTML 2025-08-30 13:44:06 +02:00
google-labs-jules[bot]
d3ed4e46e2 Translate README.md to Japanese
Create a Japanese version of the README.md file named ReadmeJP.md.
This makes the project more accessible to Japanese-speaking users.
2025-08-30 04:16:18 +00:00
Quentin Fuxa
057a1026d7 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-08-29 22:01:04 +02:00
Quentin Fuxa
1ba171a58d add embedded web interface HTML (single-file version with inline CSS/JS/SVG)
### Added
- `get_inline_ui_html()`: generates a self-contained version of the web interface, with CSS, JS, and SVG assets inlined directly into the HTML. useful for environments where serving static files is inconvenient or when a single-call UI delivery is preferred.

(cherry picked from commit aa44a92a67)
2025-08-29 22:00:59 +02:00
Quentin Fuxa
1adac67155 explanations about model persistency in containers 2025-08-29 21:27:08 +02:00
Quentin Fuxa
42be1a3773 Merge pull request #173 from CoderRahul9904/chore/docker/pytorch-timeout-retries
fix: increase pip timeout & retries for torch wheel install
2025-08-29 21:22:30 +02:00
Rahul Mourya
0a49fafa0d Update Dockerfile
fix(docker): increase pip timeout/retries for PyTorch wheel installs
2025-08-30 00:23:59 +05:30
24 changed files with 822 additions and 831 deletions

DEV_NOTES.md (new file, 70 lines)

@@ -0,0 +1,70 @@
# 1. Simulstreaming: Decouple the encoder for faster inference
SimulStreaming encoder time measurements (`whisperlivekit/simul_whisper/simul_whisper.py`, l. 397) on macOS, Apple Silicon M4:
| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
Memory saved by loading only the encoder in the optimized framework. Weight sizes for `tiny.en` with MLX Whisper:
- Decoder weights: 59110771 bytes
- Encoder weights: 15268874 bytes
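As a rough reading of the numbers above, the sketch below computes the MLX speedup over the unmodified Whisper encoder and the share of the `tiny.en` MLX weights taken by the encoder. It assumes the two byte counts together make up the whole checkpoint; the dictionary layout and variable names are illustrative only.

```python
# Encoder timings (seconds) copied from the table above.
timings = {
    "WHISPER": {"base.en": 0.35, "small": 1.09},
    "FASTER_WHISPER": {"base.en": 0.40, "small": 1.20},
    "MLX_WHISPER": {"base.en": 0.07, "small": 0.20},
}

# Speedup of the MLX encoder relative to the unmodified Whisper encoder.
mlx_speedup = {
    model: timings["WHISPER"][model] / timings["MLX_WHISPER"][model]
    for model in ("base.en", "small")
}

# Weight sizes (bytes) for tiny.en with MLX Whisper, from the notes above.
decoder_bytes = 59110771
encoder_bytes = 15268874

# Assumption: encoder + decoder account for the full checkpoint.
encoder_share = encoder_bytes / (encoder_bytes + decoder_bytes)

for model, ratio in mlx_speedup.items():
    print(f"{model}: MLX encoder is about {ratio:.1f}x faster")
print(f"Encoder share of tiny.en weights: {encoder_share:.1%}")
```

Under that assumption, loading only the encoder on the MLX side skips roughly four fifths of the MLX weights.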
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
## Problem Statement
- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers
### Initial Setup
For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.
Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
### Algorithm
```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```
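A small worked example may help read the `argsort` line. It assumes `preds_np` has shape `(T, 4)` (one row per time step), which is a guess based on the surrounding notes; the values are made up. Because `np.argsort` sorts in ascending order, the last column holds the top speaker and the one before it the runner-up.

```python
import numpy as np

# Made-up scores for 2 time steps and 4 speaker slots (assumed shape: (T, 4)).
preds_np = np.array([
    [0.70, 0.05, 0.20, 0.05],
    [0.10, 0.60, 0.25, 0.05],
])

# Indices of the two highest-scoring slots per row, in ascending score order.
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]

# DS_a: top detected speaker, DS_b: second detected speaker.
ds_a = top_2_speakers[:, -1]
ds_b = top_2_speakers[:, -2]

print(top_2_speakers)
print(ds_a, ds_b)
```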
- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b
3. **Attribution Logic**
```
AS_0 ← A
AS_1 ← B
IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
   DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
    # Likely that DS_a_0 = DS_a_1 (same speaker)
    AS_1 ← A
    AS_2 ← B
ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
     DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
    AS_2 ← A
ELSE:
    AS_2 ← B
# to finish
```
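The attribution rule for the first three predictions can be sketched in Python as follows. The notes do not define `DIST`, so this example assumes each detection is represented by a small feature vector and uses Euclidean distance; `attribute_first_three` and `euclidean` are hypothetical names, and the unfinished part of the algorithm is not covered.

```python
import numpy as np

def euclidean(u, v):
    """Assumed stand-in for DIST: Euclidean distance between two vectors."""
    return float(np.linalg.norm(np.asarray(u) - np.asarray(v)))

def attribute_first_three(ds, dist=euclidean):
    """Return [AS_0, AS_1, AS_2] for the first three top detections."""
    d01 = dist(ds[0], ds[1])
    d02 = dist(ds[0], ds[2])
    d12 = dist(ds[1], ds[2])
    attributed = ["A", "B", None]  # AS_0 <- A, AS_1 <- B
    if d01 < d02 and d01 < d12:
        # Detections 0 and 1 are closest: treat them as the same speaker.
        attributed[1] = "A"
        attributed[2] = "B"
    elif d02 < d01 and d02 < d12:
        # Detections 0 and 2 are closest: detection 2 joins speaker A.
        attributed[2] = "A"
    else:
        # Otherwise detection 2 is grouped with detection 1 (speaker B).
        attributed[2] = "B"
    return attributed

same_first_two = attribute_first_three([(0.0, 0.0), (0.1, 0.0), (5.0, 5.0)])
same_first_last = attribute_first_three([(0.0, 0.0), (5.0, 5.0), (0.1, 0.0)])
same_last_two = attribute_first_three([(0.0, 0.0), (5.0, 5.0), (5.0, 5.1)])
print(same_first_two, same_first_last, same_last_two)
```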


@@ -17,18 +17,26 @@ RUN apt-get update && \
ffmpeg \
git \
build-essential \
python3-dev && \
python3-dev \
ca-certificates && \
rm -rf /var/lib/apt/lists/*
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129
# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
# Note: For gated models, you need to add your HF token. See README.md
# for more details.
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
@@ -37,16 +45,14 @@ RUN if [ -n "$EXTRAS" ]; then \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models by:
# Note: If running multiple containers, better to map a shared
# bucket.
#
# In-container caching for Hugging Face models by:
# A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is
# only for convenience at the dev/test stage.
# For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]
# or
# B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg.
@@ -61,8 +67,7 @@ RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided
# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
@@ -70,11 +75,9 @@ RUN if [ -n "$HF_TKN_FILE" ]; then \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args
CMD ["--model", "medium"]
CMD ["--model", "medium"]


@@ -9,7 +9,7 @@
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
@@ -67,10 +67,10 @@ pip install whisperlivekit
| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisper` |
| Improved timestamps backend | `whisper-timestamped` |
| Apple Silicon optimization backend | `mlx-whisper` |
| **Apple Silicon optimized backend** | `mlx-whisper` |
| *[Not recommended]* Speaker diarization with Diart | `diart` |
| *[Not recommended]* Original Whisper backend | `whisper` |
| *[Not recommended]* Improved timestamps backend | `whisper-timestamped` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them.
@@ -128,7 +128,7 @@ async def websocket_endpoint(websocket: WebSocket):
await audio_processor.process_audio(message)
```
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`
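The renamed helper returns a single self-contained page. How it inlines assets is not shown in this diff, so the sketch below only illustrates the general idea of replacing stylesheet and script references with their contents. `inline_assets`, the regular expressions, and the sample page are all hypothetical and are not the package's implementation.

```python
import re

def inline_assets(html, assets):
    """Replace <link>/<script src> references with inline <style>/<script>.

    `assets` maps a referenced file name to its text content.
    """
    def inline_css(match):
        return f"<style>{assets[match.group(1)]}</style>"

    def inline_js(match):
        return f"<script>{assets[match.group(1)]}</script>"

    html = re.sub(r'<link rel="stylesheet" href="([^"]+)"\s*/?>', inline_css, html)
    html = re.sub(r'<script src="([^"]+)"></script>', inline_js, html)
    return html

page = (
    '<html><head><link rel="stylesheet" href="style.css"></head>'
    '<body><script src="app.js"></script></body></html>'
)
assets = {"style.css": "body{margin:0}", "app.js": "console.log('hi')"}
result = inline_assets(page, assets)
print(result)
```

A page built this way can be returned from a single HTTP handler without serving any static files, which matches the motivation given for `get_inline_ui_html()`.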
## Parameters & Configuration
@@ -138,6 +138,7 @@ An important list of parameters can be changed. But what *should* you change?
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--task translate`, to translate into English
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
@@ -159,14 +160,9 @@ The rest I don't recommend. But below are your options.
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -180,6 +176,12 @@ The rest I don't recommend. But below are your options.
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |

ReadmeJP.md (new file, 258 lines)

@@ -0,0 +1,258 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>Real-time, fully local speech-to-text with speaker identification</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
A ready-to-use backend and server with a simple frontend, delivering real-time speech transcription directly to your browser. ✨
#### Powered by leading research:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low-latency transcription with the AlignAtt policy
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low-latency transcription with the LocalAgreement policy
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade voice activity detection
> **Why not just run a plain Whisper model on each audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts words off mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.
### Architecture
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*The backend supports multiple concurrent users. Voice activity detection reduces overhead when no speech is detected.*
### Installation & Quick Start
```bash
pip install whisperlivekit
```
> **FFmpeg is required** and must be installed before using WhisperLiveKit.
>
> | OS | How to install |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | Download the .exe from https://ffmpeg.org/download.html and add it to PATH |
#### Quick Start
1. **Start the transcription server:**
```bash
whisperlivekit-server --model base --language en
```
2. **Open your browser** and go to `http://localhost:8000`. Start speaking and your words appear in real time!
> - For the list of all available languages, see [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py).
> - For HTTPS requirements, see the SSL configuration options in the **Parameters** section.
#### Optional Dependencies
| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisper` |
| Improved timestamps backend | `whisper-timestamped` |
| Apple Silicon optimization backend | `mlx-whisper` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them.
### Usage Examples
**Command-line Interface**: Start the transcription server with various options:
```bash
# Use a better model than the default (small)
whisperlivekit-server --model large-v3
# Advanced configuration with diarization and language
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# Create a new AudioProcessor for each connection, passing the shared engine
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket.accept()
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
```
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
## Parameters & Configuration
Many parameters can be changed. But what *should* you change?
- the `--model` size. List and recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend`: you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
The rest I don't recommend. But below are your options.
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--no-vac` | Disable the Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
| `--audio-max-len` | Maximum audio buffer length (seconds) | `30.0` |
| `--audio-min-len` | Minimum audio length to process (seconds) | `0.0` |
| `--cif-ckpt-path` | Path to the CIF model for word boundary detection | `None` |
| `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that does not scroll | `None` |
| `--max-context-tokens` | Maximum number of context tokens | `None` |
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory (set up to the expected number of concurrent users) | `1` |
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--segmentation-model` | Hugging Face model ID for the Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for the Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> Diarization with Diart requires access to pyannote.audio models:
> 1. [Accept the user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept the user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept the user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
> 4. Log in with HuggingFace: `huggingface-cli login`
### 🚀 Deployment Guide
To deploy WhisperLiveKit in production:
1. **Server Setup**: Install a production ASGI server and launch it with multiple workers
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **Frontend**: Host your customized version of the `html` and make sure the WebSocket connection points to the right address
3. **Nginx Configuration** (recommended for production):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in the WebSocket URL
## 🐋 Docker
Deploy the application easily using Docker with GPU or CPU support.
### Prerequisites
- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed
### Quick Start
**With GPU acceleration (recommended):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
**CPU only:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### Advanced Usage
**Custom configuration:**
```bash
# Example with a custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
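### Memory Requirements
- **Large models**: Make sure your Docker runtime has enough memory allocated
#### Customization
- `--build-arg` options:
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set the necessary container options!
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for a faster first start
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
## 🔮 Use Cases
Capture discussions in real time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, automatically transcribe podcasts or videos for content creation, transcribe support calls with speaker identification for customer service...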

Binary file not shown (388 KiB before, 368 KiB after).

demo.png: binary file not shown (423 KiB before, 449 KiB after).


@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.7"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
version = "0.2.8"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
@@ -18,6 +18,11 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
]
@@ -28,7 +33,8 @@ dependencies = [
"faster-whisper",
"uvicorn",
"websockets",
"torch",
"torchaudio>=2.0.0",
"torch>=2.0.0",
"tqdm",
"tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'


@@ -19,6 +19,15 @@ transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
#to remove after 0.2.8
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
logger.warning(f"""
{'='*50}
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
{'='*50}
""")
global transcription_engine
transcription_engine = TranscriptionEngine(
**vars(args),


@@ -46,6 +46,7 @@ class TranscriptionEngine:
"confidence_validation": False,
"buffer_trimming_sec": 15,
# simulstreaming params:
"disable_fast_encoder": False,
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
@@ -60,7 +61,7 @@ class TranscriptionEngine:
"diarization_backend": "sortformer",
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
"embedding_model": "pyannote/embedding",
}
config_dict = {**defaults, **kwargs}
@@ -97,7 +98,7 @@ class TranscriptionEngine:
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']:
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)
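The two patterns in this hunk (defaults overridden by keyword arguments, then a filtered copy of selected attributes) can be tried in isolation. This sketch assumes the merged config ends up on an `argparse.Namespace`-like object, which the diff does not show, and uses a shortened attribute list for brevity.

```python
from argparse import Namespace

# A few of the defaults from the hunk above; later keys win in a dict merge.
defaults = {"disable_fast_encoder": False, "frame_threshold": 25, "beams": 1}
kwargs = {"disable_fast_encoder": True}
config_dict = {**defaults, **kwargs}

# Assumption: the engine stores the merged config on a Namespace-like object.
args = Namespace(**config_dict)

# Collect only the attributes that exist; "model_path" is absent here and skipped.
simulstreaming_kwargs = {}
for attr in ["frame_threshold", "beams", "model_path", "disable_fast_encoder"]:
    if hasattr(args, attr):
        simulstreaming_kwargs[attr] = getattr(args, attr)

print(simulstreaming_kwargs)
```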


@@ -161,6 +161,14 @@ def parse_args():
# SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--disable-fast-encoder",
action="store_true",
default=False,
dest="disable_fast_encoder",
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
)
simulstreaming_group.add_argument(
"--frame-threshold",
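The behavior of the new flag can be checked with a stand-alone parser. This sketch only reproduces the `--disable-fast-encoder` argument from the hunk above; the parser and group objects here are local to the example, not the project's `parse_args()`.

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("SimulStreaming arguments")
group.add_argument(
    "--disable-fast-encoder",
    action="store_true",
    default=False,
    dest="disable_fast_encoder",
    help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed).",
)

# Without the flag the fast encoder stays enabled; with it, it is disabled.
args_default = parser.parse_args([])
args_disabled = parser.parse_args(["--disable-fast-encoder"])
print(args_default.disable_fast_encoder, args_disabled.disable_fast_encoder)
```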


@@ -13,15 +13,25 @@ import os
import gc
logger = logging.getLogger(__name__)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
try:
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
except ImportError as e:
raise ImportError(
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
try:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# TOO_MANY_REPETITIONS = 3
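The detection order in this hunk (MLX Whisper first, Faster Whisper only if MLX is missing) can be sketched without importing either package. This is not the project's code: it swaps the try/except imports for `importlib.util.find_spec`, and `pick_fast_encoder` and its `has_module` parameter are hypothetical names.

```python
import importlib.util

def _module_available(name):
    """True if a top-level module can be found without importing it."""
    return importlib.util.find_spec(name) is not None

def pick_fast_encoder(disable_fast_encoder=False, has_module=_module_available):
    """Return "mlx", "faster-whisper", or None, in the same priority order."""
    if disable_fast_encoder:
        return None
    if has_module("mlx_whisper"):
        return "mlx"
    if has_module("faster_whisper"):
        return "faster-whisper"
    return None

# Simulated environments, so the example does not depend on installed packages.
both_installed = pick_fast_encoder(has_module=lambda name: True)
only_fw = pick_fast_encoder(has_module=lambda name: name == "faster_whisper")
disabled = pick_fast_encoder(disable_fast_encoder=True, has_module=lambda name: True)
print(both_installed, only_fw, disabled)
```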
@@ -51,7 +61,10 @@ class SimulStreamingOnlineProcessor:
model = self.asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper(
cfg=self.asr.cfg,
loaded_model=model)
loaded_model=model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def insert_silence(self, silence_duration, offset):
"""
@@ -231,7 +244,8 @@ class SimulStreamingASR():
self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1)
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
self.fast_encoder = False
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
@@ -276,15 +290,44 @@ class SimulStreamingASR():
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.models = [self.load_model() for i in range(self.preload_model_count)]
self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
self.fast_encoder = True
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_encoder = WhisperModel(
self.model_name,
device='auto',
compute_type='auto',
)
self.fast_encoder = True
self.models = [self.load_model() for i in range(self.preload_model_count)]
def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path)
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = PaddedAlignAttWhisper(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model
def get_new_model_instance(self):


@@ -0,0 +1,72 @@
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper
mlx_model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
def load_mlx_encoder(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
# we only want to load the encoder weights here.
# Size examples: for tiny.en,
# Decoder weights: 59110771 bytes
# Encoder weights: 15268874 bytes
encoder_weights = {}
encoder_weights['encoder'] = weights['encoder']
del(weights)
model.update(encoder_weights)
mx.eval(model.parameters())
return model
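The encoder-only selection at the end of `load_mlx_encoder` can be illustrated with plain dictionaries. The `unflatten` helper below is a simplified stand-in for `mlx.utils.tree_unflatten` (it handles dotted keys only and does not build lists for numeric keys), and the weight names and values are invented for the example.

```python
def unflatten(flat):
    """Turn {"a.b.c": value} into nested dicts {"a": {"b": {"c": value}}}."""
    tree = {}
    for dotted_key, value in flat.items():
        *parents, leaf = dotted_key.split(".")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree

# Invented flat checkpoint with both encoder and decoder entries.
flat_weights = {
    "encoder.conv1.weight": [0.1, 0.2],
    "encoder.ln_post.bias": [0.0],
    "decoder.token_embedding.weight": [0.3, 0.4, 0.5],
}

weights = unflatten(flat_weights)

# Keep only the encoder subtree, as the function above does before model.update().
encoder_weights = {"encoder": weights["encoder"]}
del weights

print(sorted(encoder_weights["encoder"]))
```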


@@ -14,7 +14,7 @@ from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens,
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os
from time import time
from .token_buffer import TokenBuffer
import numpy as np
@@ -23,8 +23,22 @@ from .generation_progress import *
DEC_PAD = 50257
logger = logging.getLogger(__name__)
import sys
import wave
try:
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
@@ -33,7 +47,13 @@ import wave
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
def __init__(
self,
cfg: AlignAttConfig,
loaded_model=None,
mlx_encoder=None,
fw_encoder=None,
) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
@@ -42,6 +62,11 @@ class PaddedAlignAttWhisper:
else:
self.model = load_model(name=model_name, download_root=model_path)
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
logger.info(f"Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
@@ -151,6 +176,15 @@ class PaddedAlignAttWhisper:
for hook in self.l_hooks:
hook.remove()
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
@@ -359,20 +393,36 @@ class PaddedAlignAttWhisper:
else:
input_segments = self.segments[0]
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
# encode
encoder_feature = self.model.encoder(mel)
# NEW : we can use a different encoder, before using standard whisper for cross attention with the hooks on the decoder
beg_encode = time()
if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
device = 'cpu'
elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate))
device = 'cpu'
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
device = mel.device
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ")
@@ -397,7 +447,7 @@ class PaddedAlignAttWhisper:
####################### Decoding loop
logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
completed = False
attn_of_alignment_heads = None
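
The three encoder branches above compute `content_mel_len` in different ways, but each appears to target the same quantity: the number of encoder positions covered by real audio rather than padding. A minimal sketch of that arithmetic, assuming 16 kHz audio, a hop of 160 samples per mel frame (100 frames per second), and a final halving for the encoder's stride-2 convolution. The constants and function names are illustrative, not taken from the repository:

```python
SAMPLE_RATE = 16_000  # Hz, assumed (matches the literal 16000 in the diff)
HOP_LENGTH = 160      # samples per mel frame, assumed Whisper default
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH  # 100 mel frames per second


def content_len_from_seconds(num_samples: int) -> int:
    """Mirror of the faster-whisper branch: seconds -> mel frames -> halve."""
    audio_length_seconds = num_samples / SAMPLE_RATE
    return int(audio_length_seconds * FRAMES_PER_SECOND) // 2


def content_len_from_frames(num_samples: int) -> int:
    """Integer-only variant: count whole mel frames, then halve."""
    return (num_samples // HOP_LENGTH) // 2


for samples in (8_000, 16_000, 48_000):
    print(samples, content_len_from_seconds(samples), content_len_from_frames(samples))
```

The integer-only variant avoids floating-point rounding for sample counts that are not a whole number of frames; whether that difference matters for the alignment logic is not established by the diff.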


@@ -105,6 +105,7 @@ def load_model(
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only=False
) -> Whisper:
"""
Load a Whisper ASR model
@@ -151,7 +152,14 @@ def load_model(
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
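
The `decoder_only` path above drops every checkpoint entry whose key contains the substring `encoder`. A minimal sketch of the same idea on a plain dictionary, using a prefix test instead of a substring test. The helper name and the toy keys are assumptions for illustration; real values would be tensors:

```python
from typing import Any, Dict


def drop_encoder_weights(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of state_dict without entries under the 'encoder.' prefix."""
    return {k: v for k, v in state_dict.items() if not k.startswith("encoder.")}


# Toy stand-in for checkpoint["model_state_dict"].
toy_state = {
    "encoder.conv1.weight": "enc-w",
    "encoder.blocks.0.attn.query.weight": "enc-q",
    "decoder.token_embedding.weight": "dec-emb",
    "decoder.blocks.0.cross_attn.key.weight": "dec-k",
}

decoder_only_state = drop_encoder_weights(toy_state)
print(sorted(decoder_only_state))
```

A prefix test only removes parameters that live under the encoder submodule, whereas a substring test would also remove any decoder parameter that happened to have `encoder` somewhere in its name. For the keys shown here both tests give the same result.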


@@ -253,16 +253,18 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
if not decoder_only:
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,


@@ -1,60 +0,0 @@
# gemma_translate.py
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "google/gemma-3-270m-it"
def build_prompt(tokenizer, text, target_lang, source_lang=None):
# Use the model's chat template for best results
if source_lang:
user_msg = (
f"Translate the following {source_lang} text into {target_lang}.\n"
f"Return only the translation.\n\n"
f"Text:\n{text}"
)
else:
user_msg = (
f"Translate the following text into {target_lang}.\n"
f"Return only the translation.\n\n"
f"Text:\n{text}"
)
chat = [{"role": "user", "content": user_msg}]
return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
def translate(text, target_lang, source_lang=None, max_new_tokens=256, temperature=0.2, top_p=0.95):
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
prompt = build_prompt(tokenizer, text, target_lang, source_lang)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature > 0.0,
eos_token_id=tokenizer.eos_token_id,
)
# Slice off the prompt to keep only the assistant answer
generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
out = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
return out
if __name__ == "__main__":
ap = argparse.ArgumentParser(description="Translate with google/gemma-3-270m-it")
ap.add_argument("--text", required=True, help="Text to translate")
ap.add_argument("--to", dest="target_lang", required=True, help="Target language (e.g., French, Spanish)")
ap.add_argument("--from", dest="source_lang", default=None, help="Source language (optional)")
ap.add_argument("--temp", type=float, default=0.2, help="Sampling temperature (0 = deterministic-ish)")
ap.add_argument("--max-new", type=int, default=256, help="Max new tokens")
args = ap.parse_args()
print(translate(args.text, args.target_lang, args.source_lang, max_new_tokens=args.max_new, temperature=args.temp))


@@ -1,121 +0,0 @@
# nllb_translate.py
import argparse
from pathlib import Path
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MODEL_ID = "facebook/nllb-200-distilled-600M"
# Common language shortcuts → NLLB codes (extend as needed)
LANG_MAP = {
"english": "eng_Latn",
"en": "eng_Latn",
"french": "fra_Latn",
"fr": "fra_Latn",
"spanish": "spa_Latn",
"es": "spa_Latn",
"german": "deu_Latn",
"de": "deu_Latn",
"italian": "ita_Latn",
"it": "ita_Latn",
"portuguese": "por_Latn",
"pt": "por_Latn",
"arabic": "arb_Arab",
"ar": "arb_Arab",
"russian": "rus_Cyrl",
"ru": "rus_Cyrl",
"turkish": "tur_Latn",
"tr": "tur_Latn",
"chinese": "zho_Hans",
"zh": "zho_Hans", # Simplified
"zh-cn": "zho_Hans",
"zh-hans": "zho_Hans",
"zh-hant": "zho_Hant", # Traditional
"japanese": "jpn_Jpan",
"ja": "jpn_Jpan",
"korean": "kor_Hang",
"ko": "kor_Hang",
"dutch": "nld_Latn",
"nl": "nld_Latn",
"polish": "pol_Latn",
"pl": "pol_Latn",
"swedish": "swe_Latn",
"sv": "swe_Latn",
"norwegian": "nob_Latn",
"no": "nob_Latn",
"danish": "dan_Latn",
"da": "dan_Latn",
"finnish": "fin_Latn",
"fi": "fin_Latn",
"catalan": "cat_Latn",
"ca": "cat_Latn",
"hindi": "hin_Deva",
"hi": "hin_Deva",
"vietnamese": "vie_Latn",
"vi": "vie_Latn",
"indonesian": "ind_Latn",
"id": "ind_Latn",
"thai": "tha_Thai",
"th": "tha_Thai",
}
def norm_lang(code: str) -> str:
c = code.strip().lower()
return LANG_MAP.get(c, code)
def translate_texts(texts: List[str], src_code: str, tgt_code: str,
max_new_tokens=512, device=None, dtype=None) -> List[str]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, src_lang=src_code)
model = AutoModelForSeq2SeqLM.from_pretrained(
MODEL_ID,
torch_dtype=dtype if dtype is not None else (torch.float16 if torch.cuda.is_available() else torch.float32),
device_map="auto" if torch.cuda.is_available() else None,
)
if device:
model.to(device)
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
if device or torch.cuda.is_available():
inputs = {k: v.to(model.device) for k, v in inputs.items()}
forced_bos = tokenizer.convert_tokens_to_ids(tgt_code)
with torch.no_grad():
gen = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
forced_bos_token_id=forced_bos,
)
outs = tokenizer.batch_decode(gen, skip_special_tokens=True)
return [o.strip() for o in outs]
def main():
ap = argparse.ArgumentParser(description="Translate with facebook/nllb-200-distilled-600M")
ap.add_argument("--text", help="Inline text to translate")
ap.add_argument("--file", help="Path to a UTF-8 text file (one example per line)")
ap.add_argument("--src", required=True, help="Source language (e.g. fr, fra_Latn)")
ap.add_argument("--tgt", required=True, help="Target language (e.g. en, eng_Latn)")
ap.add_argument("--max-new", type=int, default=512, help="Max new tokens")
args = ap.parse_args()
src = norm_lang(args.src)
tgt = norm_lang(args.tgt)
batch: List[str] = []
if args.text:
batch.append(args.text)
if args.file:
lines = Path(args.file).read_text(encoding="utf-8").splitlines()
batch.extend([ln for ln in lines if ln.strip()])
if not batch:
raise SystemExit("Provide --text or --file")
results = translate_texts(batch, src, tgt, max_new_tokens=args.max_new)
for i, (inp, out) in enumerate(zip(batch, results), 1):
print(f"\n--- Sample {i} ---")
print(f"SRC [{src}]: {inp}")
print(f"TGT [{tgt}]: {out}")
if __name__ == "__main__":
main()
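
The removed `norm_lang` helper maps user-friendly shortcuts to NLLB language codes and passes anything else through unchanged. A small usage sketch with a reduced copy of the map (only four entries are reproduced here, for illustration):

```python
# Reduced copy of LANG_MAP from the removed script, for illustration only.
LANG_MAP = {
    "english": "eng_Latn",
    "en": "eng_Latn",
    "french": "fra_Latn",
    "fr": "fra_Latn",
}


def norm_lang(code: str) -> str:
    c = code.strip().lower()
    return LANG_MAP.get(c, code)


print(norm_lang(" FR "))      # shortcut lookup ignores case and surrounding whitespace
print(norm_lang("deu_Latn"))  # not in the reduced map, so returned as given
```

Because the fallback returns the original `code` rather than the stripped value, an unknown code keeps any whitespace the caller passed in.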


@@ -1,38 +0,0 @@
import regex
from functools import lru_cache
class SentenceSegmenter:
"""
Regex sentence splitter for Latin languages, Japanese and Chinese.
It is based on sacrebleu TokenizerV14International(BaseTokenizer).
Returns: a list of strings, where each string is a sentence.
Spaces following punctuation are appended after punctuation within the sequence.
Total number of characters in the output is the same as in the input.
"""
sep = 'ŽžŽžSentenceSeparatorŽžŽž' # string that certainly won't be in src or target
latin_terminals = '!?.'
jap_zh_terminals = '。!?'
terminals = latin_terminals + jap_zh_terminals
def __init__(self):
# end of sentence characters:
terminals = self.terminals
self._re = [
# Separate out punctuation preceded by a non-digit.
# If followed by space-like sequence of characters, they are
# appended to the punctuation, not to the next sequence.
(regex.compile(r'(\P{N})(['+terminals+r'])(\p{Z}*)'), r'\1\2\3'+self.sep),
# Separate out punctuations followed by a non-digit
(regex.compile(r'('+terminals+r')(\P{N})'), r'\1'+self.sep+r'\2'),
# # Separate out symbols
# -> no, we don't tokenize but segment the punctuation
# (regex.compile(r'(\p{S})'), r' \1 '),
]
@lru_cache(maxsize=2**16)
def __call__(self, line):
for (_re, repl) in self._re:
line = _re.sub(repl, line)
return [ t for t in line.split(self.sep) if t != '' ]


@@ -1,466 +0,0 @@
import sys
import ctranslate2
import sentencepiece as spm
import transformers
import argparse
def generate_words(sp, step_results):
tokens_buffer = []
for step_result in step_results:
is_new_word = step_result.token.startswith("▁")
if is_new_word and tokens_buffer:
word = sp.decode(tokens_buffer)
if word:
yield word
tokens_buffer = []
tokens_buffer.append(step_result.token_id)
if tokens_buffer:
word = sp.decode(tokens_buffer)
if word:
yield word
from sentence_segmenter import SentenceSegmenter
class LLMTranslator:
def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None):
self.system_prompt = system_prompt
print("Loading the model...", file=sys.stderr)
self.generator = ctranslate2.Generator("ct2_EuroLLM-9B-Instruct/", device="cuda")
self.sp = spm.SentencePieceProcessor("EuroLLM-9B-Instruct/tokenizer.model")
self.tokenizer = transformers.AutoTokenizer.from_pretrained("EuroLLM-9B-Instruct/")
print("...done", file=sys.stderr)
self.max_context_length = max_context_length
self.max_tokens_to_trim = self.max_context_length - 10
self.len_ratio = len_ratio
# my regex sentence segmenter
self.segmenter = SentenceSegmenter()
# self.max_generation_length = 512
# self.max_prompt_length = context_length - max_generation_length
def start_dialog(self):
return [{'role':'system', 'content': self.system_prompt }]
def build_prompt(self, dialog):
toks = self.tokenizer.apply_chat_template(dialog, tokenize=True, add_generation_prompt=False)
if len(dialog) == 3:
toks = toks[:-2]
print("len toks:", len(toks), file=sys.stderr)
# print(toks, file=sys.stderr)
c = self.tokenizer.convert_ids_to_tokens(toks)
# print(c,file=sys.stderr)
return c
def translate(self, src, tgt_forced=""):
#src, tgt_forced = self.trim(src, tgt_forced)
dialog = self.start_dialog()
dialog += [{'role':'user','content': src}]
if tgt_forced != "":
dialog += [{'role':'assistant','content': tgt_forced}]
prompt_tokens = self.build_prompt(dialog)
if self.len_ratio is not None:
limit_len = int(len(self.tokenizer.encode(src)) * self.len_ratio) + 10
limit_kw = {'max_length': limit_len}
else:
limit_kw = {}
step_results = self.generator.generate_tokens(
prompt_tokens,
**limit_kw,
# end_token=tokenizer.eos_token,
# sampling_temperature=0.6,
# sampling_topk=20,
# sampling_topp=1,
)
res = []
#output_ids = []
for step_result in step_results:
# is_new_word = step_result.token.startswith("▁")
# if is_new_word and output_ids:
# word = self.sp.decode(output_ids)
# print(word, end=" ", flush=True, file=sys.stderr)
# output_ids = []
# output_ids.append(step_result.token_id)
res.append(step_result)
#if output_ids:
# word = self.sp.decode(output_ids)
# print(word, file=sys.stderr)
return self.sp.decode([r.token_id for r in res])
# print(res)
# print([s.token for s in res], file=sys.stderr)
# print([s.token==self.tokenizer.eos_token for s in res], file=sys.stderr)
class ParallelTextBuffer:
def __init__(self, tokenizer, max_tokens, trimming="segments", init_src="", init_tgt=""):
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.src_buffer = [] # list of lists
if init_src:
self.src_buffer.append(init_src)
self.tgt_buffer = [] # list of strings
if init_tgt:
self.tgt_buffer.append(init_tgt)
self.trimming = trimming
if self.trimming == "sentences":
self.segmenter = SentenceSegmenter()
def len_src(self):
return sum(len(t) for t in self.src_buffer) + len(self.src_buffer) - 1
def insert(self, src, tgt):
self.src_buffer.append(src)
self.tgt_buffer.append(tgt)
def insert_src_suffix(self, s):
if self.src_buffer:
self.src_buffer[-1][-1] += s
else:
self.src_buffer.append([s])
def trim_sentences(self):
# src_tok_lens = [len(self.tokenizer.encode(" ".join(b))) for b in self.src_buffer]
# tgt_tok_lens = [len(self.tokenizer.encode(t)) for t in self.tgt_buffer]
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
def trim_sentence(text):
sents = self.segmenter(text)
print("SENTS:", len(sents), sents, file=sys.stderr)
return "".join(sents[1:])
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
nsrc = trim_sentence(src)
ntgt = trim_sentence(tgt)
if not nsrc or not ntgt:
print("src or tgt is empty after trimming.", file=sys.stderr)
print("src: ", src, file=sys.stderr)
print("tgt: ", tgt, file=sys.stderr)
break
src = nsrc
tgt = ntgt
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
print("TRIMMED SRC:", (src,), file=sys.stderr)
print("TRIMMED TGT:", (tgt,), file=sys.stderr)
self.src_buffer = [src.split()]
self.tgt_buffer = [tgt]
return src, tgt
def trim_segments(self):
print("BUFFER:", file=sys.stderr)
for s,t in zip(self.src_buffer, self.tgt_buffer):
print("\t", s,"...",t,file=sys.stderr) #,self.src_buffer, self.tgt_buffer, file=sys.stderr)
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
if len(self.src_buffer) > 1 and len(self.tgt_buffer) > 1:
self.src_buffer.pop(0)
self.tgt_buffer.pop(0)
else:
break
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
print("TRIMMED SEGMENTS SRC:", (src,), file=sys.stderr)
print("TRIMMED SEGMENTS TGT:", (tgt,), file=sys.stderr)
return src, tgt
def trim(self):
if self.trimming == "sentences":
return self.trim_sentences()
return self.trim_segments()
class SimulLLM:
def __init__(self, llmtrans, min_len=0, chunk=1, trimming="sentences", language="ja", init_src="", init_tgt=""):
self.llmtranslator = llmtrans
#self.src_buffer = init_src
#self.confirmed_tgt = init_tgt
self.buffer = ParallelTextBuffer(self.llmtranslator.tokenizer, self.llmtranslator.max_tokens_to_trim, trimming=trimming, init_src=init_src, init_tgt=init_tgt)
self.last_inserted = []
self.last_unconfirmed = ""
self.min_len = min_len
self.step = chunk
self.language = language
if language in ["ja", "zh"]:
self.specific_space = ""
else:
self.specific_space = " "
def insert(self, src):
if isinstance(src, str):
self.last_inserted.append(src)
else:
self.last_inserted += src
def insert_suffix(self, text):
'''
Insert suffix of a word to the last inserted word.
It may be because the word was split to multiple parts in the input, each with different timestamps.
'''
if self.last_inserted:
self.last_inserted[-1] += text
elif self.buffer.src_buffer:
self.buffer.insert_src_suffix(text)
else:
# this shouldn't happen
self.last_inserted.append(text)
def trim_longest_common_prefix(self, a,b):
if self.language not in ["ja", "zh"]:
a = a.split()
b = b.split()
i = 0
for i,(x,y) in enumerate(zip(a,b)):
if x != y:
break
if self.language in ["ja", "zh"]:
#print("tady160",(a, b, i), file=sys.stderr)
return a[:i], b[i:]
else:
return " ".join(a[:i]), " ".join(b[i:])
def process_iter(self):
if self.buffer.len_src() + len(self.last_inserted) < self.min_len:
return ""
src, forced_tgt = self.buffer.trim() #llmtranslator.trim(" ".join(self.src_buffer), self.confirmed_tgt)
#self.src_buffer = self.src_buffer.split()
#src = " ".join(self.src_buffer)
confirmed_out = ""
run = False
for i in range(self.step, len(self.last_inserted), self.step):
for w in self.last_inserted[i-self.step:i]:
src += " " + w
run = True
if not run: break
print("SRC",src,file=sys.stderr)
print("FORCED TGT",forced_tgt,file=sys.stderr)
out = self.llmtranslator.translate(src, forced_tgt)
print("OUT",out,file=sys.stderr)
confirmed, unconfirmed = self.trim_longest_common_prefix(self.last_unconfirmed, out)
self.last_unconfirmed = unconfirmed
#print("tady", (self.confirmed_tgt, self.specific_space, confirmed), file=sys.stderr)
if confirmed:
# self.confirmed_tgt += self.specific_space + confirmed
# print(confirmed_out, confirmed, file=sys.stderr)
confirmed_out += self.specific_space + confirmed
print("CONFIRMED NOW:",confirmed,file=sys.stderr)
print(file=sys.stderr)
print(file=sys.stderr)
print("#################",file=sys.stderr)
if run:
self.buffer.insert(self.last_inserted, confirmed_out)
self.last_inserted = []
ret = confirmed_out
print("RET:",ret,file=sys.stderr)
return ret
def finalize(self):
return self.last_unconfirmed
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input-instance', type=str, default=None, help="Filename of instances to simulate input. If not set, txt input is read from stdin.")
#parser.add_argument('--output_instance', type=str, default=None, help="Write output as instance into this file, while also writing to stdout.")
parser.add_argument('--min-chunk-size', type=int, default=1,
help='Minimum number of space-delimited words to process in each LocalAgreement update. The more, the higher quality, but slower.')
parser.add_argument('--min-len', type=int, default=1,
help='Minimum number of space-delimited words at the beginning.')
#parser.add_argument('--start_at', type=int, default=0, help='Skip first N words.')
# maybe later
#parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
#parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
lan_to_name = {
"de": "German",
"ja": "Japanese",
"zh-tr": "Chinese Traditional",
"zh-sim": "Chinese Simplified",
"cs": "Czech",
}
parser.add_argument('--lan', '--language', type=str, default="de",
help="Target language code.",
choices=["de", "ja","zh-tr","zh-sim","cs"])
SrcLang = "English" # always
TgtLang = "German"
default_prompt="You are simultaneous interpreter from {SrcLang} to {TgtLang}. We are at a conference. It is important that you translate " + \
"only what you hear, nothing else!"
parser.add_argument('--sys_prompt', type=str, default=None,
help='System prompt. If None, default one is used, depending on the language. The prompt should ')
default_init = "Please, go ahead, you can start with your presentation, we are ready."
default_inits_tgt = {
'de': "Bitte schön, Sie können mit Ihrer Präsentation beginnen, wir sind bereit.",
'ja': "どうぞ、プレゼンテーションを始めてください。", # # Please go ahead and start your presentation. # this is in English
'zh-tr': "請繼續,您可以開始您的簡報,我們已經準備好了。",
'zh-sim': "请吧,你可以开始发言了,我们已经准备好了。",
'cs': "Prosím, můžete začít s prezentací, jsme připraveni.",
}
parser.add_argument('--init_prompt_src', type=str, default=None, help='Init translation with source text. It should be a complete sentence in the source language. '
'It can be context specific for the given input. Default is ')
parser.add_argument('--init_prompt_tgt', type=str, default=None, help='Init translation with this target. It should be example translation of init_prompt_src. '
' There is default init message, depending on the language.')
parser.add_argument('--len-threshold', type=float, default=None, help='Ratio of the length of the source and generated target, in number of sentencepiece tokens. '
'It should reflect the target language and. If not set, no len-threshold is used.')
# how many times is target text longer than English
lan_thresholds = {
'de': 1.3, # 12751/9817 ... the proportion of subword tokens for ACL6060 dev de vs. en text, for EuroLLM-9B-Instruct tokenizer
'ja': 1.34, # 13187/9817
'zh': 1.23, # 12115/9817
'zh-tr': 1.23, # 12115/9817
'zh-sim': 1.23, # 12115/9817
# 'cs': I don't know # guessed
}
parser.add_argument('--language-specific-len-threshold', default=False, action="store_true",
help='Use language-specific length threshold, e.g. 1.3 for German.')
parser.add_argument("--max-context-length", type=int, default=4096, help="Maximum number of tokens in the model to use.")
parser.add_argument("--buffer_trimming", type=str, default="sentences", choices=["segments","sentences"], help="Buffer trimming strategy.")
args = parser.parse_args()
if args.sys_prompt is None:
TgtLang = lan_to_name[args.lan]
sys_prompt = default_prompt.format(SrcLang=SrcLang, TgtLang=TgtLang)
else:
sys_prompt = args.sys_prompt
if args.init_prompt_src is None:
init_src = default_init.split()
if args.init_prompt_tgt is None:
init_tgt = default_inits_tgt[args.lan]
if args.lan == "ja":
init_src = 'Please go ahead and start your presentation.'.split()
print("WARNING: Default init_prompt_src not set and language is Japanese. The init_src prompt changed to be more verbose.", file=sys.stderr)
else:
print("WARNING: init_prompt_tgt is used, init_prompt_src is None, the default one. It may be wrong!", file=sys.stderr)
init_tgt = args.init_prompt_tgt
else:
init_src = args.init_prompt_src.split()
if args.init_prompt_tgt is None:
print("WARNING: init_prompt_src is used, init_prompt_tgt is None, so the default one is used. It may be wrong!", file=sys.stderr)
init_tgt = default_inits_tgt[args.lan]
else:
init_tgt = args.init_prompt_tgt
print("INFO: System prompt:", sys_prompt, file=sys.stderr)
print("INFO: Init prompt src:", init_src, file=sys.stderr)
print("INFO: Init prompt tgt:", init_tgt, file=sys.stderr)
if args.language_specific_len_threshold:
if args.len_threshold is not None:
print("ERROR: --len-threshold is set, but --language-specific-len-threshold is also set. Only one can be used.", file=sys.stderr)
sys.exit(1)
else:
len_threshold = lan_thresholds[args.lan]
else:
len_threshold = args.len_threshold
llmtrans = LLMTranslator(system_prompt=sys_prompt, max_context_length=args.max_context_length, len_ratio=len_threshold)
lan = args.lan if not args.lan.startswith("zh") else "zh"
simul = SimulLLM(llmtrans,language=lan, min_len=args.min_len, chunk=args.min_chunk_size,
init_src=init_src, init_tgt=init_tgt, trimming=args.buffer_trimming
)
# two input options
if args.input_instance is not None:
print("INFO: Reading input from file", args.input_instance, file=sys.stderr)
import json
with open(args.input_instance, "r") as f:
instance = json.load(f)
asr_source = instance["prediction"]
timestamps = instance["delays"]
elapsed = instance["elapsed"]
yield_ts_words = zip(timestamps, timestamps, elapsed, asr_source.split())
else:
print("INFO: Reading stdin in txt format", file=sys.stderr)
def yield_input():
for line in sys.stdin:
line = line.strip()
ts, beg, end, *_ = line.split()
text = line[len(ts)+len(beg)+len(end)+3:]
ts = float(ts)
# in rare cases, the first word is a suffix of the previous word, that was split to multiple parts
if text[0] != " ":
first, *words = text.split()
yield (ts, beg, end, " "+first) # marking the first word with " ", so that it can be later detected and inserted as suffix
else:
words = text.split()
for w in words:
yield (ts, beg, end, w)
yield_ts_words = yield_input()
#i = 0
for t,b,e,w in yield_ts_words:
if w.startswith(" "): # it is suffix of the previous word
w = w[1:]
simul.insert_suffix(w)
continue
simul.insert(w)
out = simul.process_iter()
if out:
print(t,b,e,out,flush=True)
# if i > 50:
# break
# i += 1
out = simul.finalize()
print(t,b,e,out,flush=True)
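
The removed script confirms output by comparing the previous unconfirmed hypothesis with the new one and keeping their longest common prefix. A minimal sketch of that step on token lists, written independently of the removed code (the function name and sample sentences are illustrative):

```python
from typing import List, Tuple


def split_on_common_prefix(previous: List[str], current: List[str]) -> Tuple[List[str], List[str]]:
    """Split `current` into (confirmed, unconfirmed) at its longest common prefix with `previous`."""
    n = 0
    for x, y in zip(previous, current):
        if x != y:
            break
        n += 1
    return current[:n], current[n:]


prev_hyp = "Guten Morgen meine".split()
new_hyp = "Guten Morgen liebe Kollegen".split()
confirmed, unconfirmed = split_on_common_prefix(prev_hyp, new_hyp)
print(confirmed, unconfirmed)
```

Counting matches explicitly, rather than reusing the loop index after the loop ends, also confirms the last token when every compared pair matches.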


@@ -31,21 +31,21 @@ def load_file(warmup_file=None, timeout=5):
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False
return None
finally:
socket.setdefaulttimeout(original_timeout)
elif not warmup_file:
return False
return None
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False
return None
try:
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load audio file: {e}")
return False
return None
return audio
def warmup_asr(asr, warmup_file=None, timeout=5):
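
Returning `None` instead of `False` lets callers distinguish "no audio" with an identity check. This matters when the success value is a NumPy array, since evaluating a multi-element array in a boolean context raises an error. A minimal sketch of the calling pattern, using a hypothetical loader that returns a plain list so that it runs without third-party packages:

```python
import os
from typing import List, Optional


def load_samples(path: Optional[str]) -> Optional[List[float]]:
    """Hypothetical loader: None signals 'no usable warmup audio'."""
    if not path or not os.path.exists(path) or os.path.getsize(path) == 0:
        return None
    with open(path, "rb") as f:
        # Placeholder decoding: scale raw bytes to [0, 1]; not a real audio decoder.
        return [b / 255.0 for b in f.read()]


def warmup(path: Optional[str]) -> str:
    samples = load_samples(path)
    if samples is None:  # identity check, safe for any return type
        return "skipped"
    return f"warmed up on {len(samples)} samples"


print(warmup(None))
print(warmup("/nonexistent/warmup.wav"))
```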


@@ -184,7 +184,7 @@ body {
.settings {
display: flex;
flex-direction: column;
flex-wrap: wrap;
align-items: flex-start;
gap: 12px;
}
@@ -198,23 +198,27 @@ body {
#chunkSelector,
#websocketInput,
#themeSelector {
#themeSelector,
#microphoneSelect {
font-size: 16px;
padding: 5px 8px;
border-radius: 8px;
border: 1px solid var(--border);
background-color: var(--button-bg);
color: var(--text);
max-height: 34px;
max-height: 30px;
}
#websocketInput {
width: 220px;
#microphoneSelect {
width: 100%;
max-width: 190px;
min-width: 120px;
}
#chunkSelector:focus,
#websocketInput:focus,
#themeSelector:focus {
#themeSelector:focus,
#microphoneSelect:focus {
outline: none;
border-color: #007bff;
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
@@ -247,9 +251,9 @@ label {
}
.theme-selector-container {
position: absolute;
top: 20px;
right: 20px;
display: flex;
align-items: center;
margin-top: 17px;
}
.segmented label {
@@ -400,3 +404,57 @@ label {
font-size: 14px;
margin-bottom: 0px;
}
/* for smaller screens */
@media (max-width: 768px) {
.settings-container {
flex-direction: column;
gap: 10px;
}
.settings {
justify-content: center;
gap: 8px;
}
.field {
align-items: center;
}
#websocketInput,
#microphoneSelect {
min-width: 100px;
max-width: 160px;
}
.theme-selector-container {
margin-top: 10px;
}
}
@media (max-width: 480px) {
body {
margin: 10px;
}
.settings {
flex-direction: column;
align-items: center;
gap: 6px;
}
#websocketInput,
#microphoneSelect {
max-width: 140px;
}
.segmented label {
padding: 4px 8px;
font-size: 12px;
}
.segmented img {
width: 14px;
height: 14px;
}
}


@@ -1,61 +1,73 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" />
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" />
</head>
<body>
<div class="settings-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
<div class="settings-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>
</div>
<div class="timer">00:00</div>
</div>
</button>
<div class="settings">
<div class="field">
<label for="websocketInput">WebSocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
</div>
</div>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>
<p id="status"></p>
<div id="linesTranscript"></div>
<script src="/web/live_transcription.js"></script>
<p id="status"></p>
<div id="linesTranscript"></div>
<script src="/web/live_transcription.js"></script>
</body>
</html>
</html>


@@ -18,6 +18,8 @@ let animationFrame = null;
let waitingForStop = false;
let lastReceivedData = null;
let lastSignature = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
@@ -31,6 +33,7 @@ const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect");
function getWaveStroke() {
const styles = getComputedStyle(document.documentElement);
@@ -82,6 +85,61 @@ if (darkMq && darkMq.addEventListener) {
darkMq.addListener(handleOsThemeChange);
}
async function enumerateMicrophones() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
console.log(`Found ${availableMicrophones.length} microphone(s)`);
} catch (error) {
console.error('Error enumerating microphones:', error);
statusText.textContent = "Error accessing microphones. Please grant permission.";
}
}
function populateMicrophoneSelect() {
if (!microphoneSelect) return;
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
console.log(`Selected microphone: ${deviceName}`);
statusText.textContent = `Microphone changed to: ${deviceName}`;
if (isRecording) {
statusText.textContent = "Switching microphone... Please wait.";
stopRecording().then(() => {
setTimeout(() => {
toggleRecording();
}, 1000);
});
}
}
// Helpers
function fmt1(x) {
const n = Number(x);
@@ -377,7 +435,11 @@ async function startRecording() {
console.log("Error acquiring wake lock.");
}
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
@@ -516,3 +578,22 @@ function updateUI() {
}
recordButton.addEventListener("click", toggleRecording);
if (microphoneSelect) {
microphoneSelect.addEventListener("change", handleMicrophoneChange);
}
document.addEventListener('DOMContentLoaded', async () => {
try {
await enumerateMicrophones();
} catch (error) {
console.log("Could not enumerate microphones on load:", error);
}
});
navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log('Device change detected, re-enumerating microphones');
try {
await enumerateMicrophones();
} catch (error) {
console.log("Error re-enumerating microphones:", error);
}
});
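
Note on the `startRecording()` hunk above: it replaces the fixed `{ audio: true }` constraint with one built from `selectedMicrophoneId`. With `deviceId: { exact: ... }`, `getUserMedia` rejects with an `OverconstrainedError` when that device is no longer available, for example when a saved microphone has been unplugged. The sketch below shows one way that case could be handled. It is not part of the commit: `buildAudioConstraints` and `openMicrophone` are hypothetical names, and `getUserMedia` is passed in as a parameter so the logic can run outside a browser.

```javascript
// Hypothetical helper (not in the commit): builds the same constraints
// object as the ternary added to startRecording().
function buildAudioConstraints(selectedMicrophoneId) {
  return selectedMicrophoneId
    ? { audio: { deviceId: { exact: selectedMicrophoneId } } }
    : { audio: true };
}

// Hypothetical wrapper: try the selected device first, then fall back to the
// default microphone if the exact deviceId constraint cannot be satisfied.
// In the page, `getUserMedia` would be
// navigator.mediaDevices.getUserMedia.bind(navigator.mediaDevices).
async function openMicrophone(getUserMedia, selectedMicrophoneId) {
  try {
    return await getUserMedia(buildAudioConstraints(selectedMicrophoneId));
  } catch (error) {
    const deviceMissing =
      selectedMicrophoneId && error && error.name === "OverconstrainedError";
    if (deviceMissing) {
      // Retry without a deviceId so the browser picks its default input.
      return getUserMedia(buildAudioConstraints(null));
    }
    // Permission errors and anything else are left to the caller.
    throw error;
  }
}
```

Whether a silent fallback is desirable is a design choice: the commit as written surfaces the failure instead, which may be preferable if the user should be told that the chosen microphone is gone.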


@@ -16,32 +16,25 @@ def get_web_interface_html():
def get_inline_ui_html():
"""Returns the complete web interface HTML with all assets embedded in a single call."""
try:
# Load HTML template
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
html_content = f.read()
# Load CSS and embed it
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
css_content = f.read()
# Load JS and embed it
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
js_content = f.read()
# Load SVG files and convert to data URIs
# SVG files
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
system_svg = f.read()
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
light_svg = f.read()
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
dark_svg = f.read()
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
# Replace external references with embedded content
# Replace external references
html_content = html_content.replace(
'<link rel="stylesheet" href="/web/live_transcription.css" />',
f'<style>\n{css_content}\n</style>'
@@ -52,7 +45,7 @@ def get_inline_ui_html():
f'<script>\n{js_content}\n</script>'
)
# Replace SVG references with data URIs
# Replace SVG references
html_content = html_content.replace(
'<img src="/web/src/system_mode.svg" alt="" />',
f'<img src="{system_data_uri}" alt="" />'