Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
Synced 2026-03-07 22:33:36 +00:00

Compare commits: translatio...0.2.8 (16 commits)
Commits:

- 3bd2122eb4
- 50b0527858
- b044fcdec2
- b0508fcf2c
- ce89b0aebc
- d5008ed828
- d467716e26
- 199e21b3ef
- 1d926f2e67
- 4a71a391b8
- d3ed4e46e2
- 057a1026d7
- 1ba171a58d
- 1adac67155
- 42be1a3773
- 0a49fafa0d
70
DEV_NOTES.md
Normal file
@@ -0,0 +1,70 @@
# 1. Simulstreaming: Decouple the encoder for faster inference

Simulstreaming encoder time experiments (whisperlivekit/simul_whisper/simul_whisper.py, l. 397), on macOS Apple Silicon M4:

| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.40s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
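The timings above come from wall-clock measurement around the encoder call. A minimal harness of that kind (the `encode` callable here is a dummy stand-in; the real Whisper/Faster-Whisper/MLX encoders are not loaded):

```python
import time

def median_encoder_time(encode, mel, n_runs=5):
    """Run `encode` n_runs times on the same input and return the median wall-clock time."""
    times = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        encode(mel)
        times.append(time.perf_counter() - t0)
    return sorted(times)[len(times) // 2]

# Dummy encoder standing in for model.encoder(mel)
elapsed = median_encoder_time(lambda mel: sum(x * x for x in mel), list(range(10_000)))
print(f"{elapsed:.6f}s")
```

Taking the median over several runs smooths out one-off warmup costs, which can dominate the first call on some frameworks.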

Memory is saved by loading only the encoder in the optimized framework. Weight sizes for tiny.en with MLX Whisper:

- Decoder weights: 59110771 bytes
- Encoder weights: 15268874 bytes
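The saving comes from filtering the checkpoint down to the encoder subtree before loading, as `mlx_encoder.py` does with `weights['encoder']`. A dict-based sketch of the idea (byte strings stand in for real weight arrays; sizes are toy values, not the real ones above):

```python
def encoder_only(weights: dict) -> dict:
    # Keep only the encoder subtree; the decoder stays in the base framework.
    return {k: v for k, v in weights.items() if k.split(".", 1)[0] == "encoder"}

# Toy checkpoint roughly mirroring the tiny.en encoder:decoder ratio (~1:4)
ckpt = {
    "encoder.conv1.weight": b"\0" * 15,
    "decoder.token_embedding.weight": b"\0" * 59,
}
kept = encoder_only(ckpt)
saved = sum(len(v) for k, v in ckpt.items() if k not in kept)
print(saved)  # decoder bytes that are never loaded: 59
```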
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm

Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.

## Problem Statement

- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers

### Initial Setup

For each time step `i`, we have a ranking of the 4 speaker predictions (1-4). When only 2 speakers are present, the model produces close predictions for the 2 active speaker positions.

Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.

### Algorithm

```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```

Notation:

- `DS_a_{i}`: top detected speaker for prediction i
- `DS_b_{i}`: second detected speaker for prediction i
- `AS_{i}`: attributed speaker for prediction i
- `GTS_A`: ground truth speaker A
- `GTS_B`: ground truth speaker B
- `DIST(a, b)`: distance between detected speakers a and b

3. **Attribution Logic**

```
AS_0 ← A
AS_1 ← B

IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
   DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
    # Likely that DS_a_0 = DS_a_1 (same speaker)
    AS_1 ← A
    AS_2 ← B

ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
     DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
    AS_2 ← A

ELSE:
    AS_2 ← B

to finish
```
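The 4→2 constraint can be sketched end to end. This is a deliberately simplified, static variant (the dynamic per-timestep attribution in the note is still marked "to finish"): it keeps the two globally most active of the 4 speaker columns.

```python
import numpy as np

def constrain_4_to_2(preds: np.ndarray) -> np.ndarray:
    """Map (T, 4) speaker predictions to (T, 2) by keeping the two
    most active speaker columns overall."""
    # Rank the 4 columns by total activity, keep the top 2 in original order.
    top2 = np.sort(np.argsort(preds.sum(axis=0))[-2:])
    return preds[:, top2]

preds = np.array([[0.9, 0.1, 0.8, 0.0],
                  [0.7, 0.2, 0.9, 0.1]])
print(constrain_4_to_2(preds).shape)  # (2, 2)
```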
31
Dockerfile
@@ -17,18 +17,26 @@ RUN apt-get update && \
     ffmpeg \
     git \
     build-essential \
-    python3-dev && \
+    python3-dev \
+    ca-certificates && \
     rm -rf /var/lib/apt/lists/*

 RUN python3 -m venv /opt/venv
 ENV PATH="/opt/venv/bin:$PATH"
-RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129
+
+# timeout/retries for large torch wheels
+RUN pip3 install --upgrade pip setuptools wheel && \
+    pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
+    --index-url https://download.pytorch.org/whl/cu129 \
+    torch torchaudio \
+    || (echo "Initial install failed — retrying with extended timeout..." && \
+    pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
+    --index-url https://download.pytorch.org/whl/cu129 \
+    torch torchvision torchaudio)

 COPY . .

 # Install WhisperLiveKit directly, allowing for optional dependencies
-# Note: For gates models, need to add your HF toke. See README.md
+# Note: For gated models, you need to add your HF token. See README.md
 # for more details.
 RUN if [ -n "$EXTRAS" ]; then \
     echo "Installing with extras: [$EXTRAS]"; \
     pip install --no-cache-dir whisperlivekit[$EXTRAS]; \

@@ -37,16 +45,14 @@ RUN if [ -n "$EXTRAS" ]; then \
     pip install --no-cache-dir whisperlivekit; \
 fi

-# Enable in-container caching for Hugging Face models by:
+# Note: If running multiple containers, better to map a shared
+# bucket.
+#
+# In-container caching for Hugging Face models by:
 # A) Make the cache directory persistent via an anonymous volume.
 # Note: This only persists for a single, named container. This is
-# only for convenience at de/test stage.
+# only for convenience at dev/test stage.
 # For prod, it is better to use a named volume via host mount/k8s.
 VOLUME ["/root/.cache/huggingface/hub"]

 # or
 # B) Conditionally copy a local pre-cache from the build context to the
 #    container's cache via the HF_PRECACHE_DIR build-arg.

@@ -61,8 +67,7 @@ RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
     echo "No local Hugging Face cache specified, skipping copy"; \
 fi

-# Conditionally copy a Hugging Face token if provided
+# Conditionally copy a Hugging Face token if provided. Useful for the Diart backend (pyannote audio models)
 RUN if [ -n "$HF_TKN_FILE" ]; then \
     echo "Copying Hugging Face token from $HF_TKN_FILE"; \
     mkdir -p /root/.cache/huggingface && \

@@ -70,11 +75,9 @@ RUN if [ -n "$HF_TKN_FILE" ]; then \
 else \
     echo "No Hugging Face token file specified, skipping token setup"; \
 fi

-# Expose port for the transcription server
 EXPOSE 8000

 ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

+# Default args
 CMD ["--model", "medium"]
26
README.md
@@ -9,7 +9,7 @@
 <p align="center">
     <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
     <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
-    <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
+    <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
     <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
 </p>

@@ -67,10 +67,10 @@ pip install whisperlivekit
 | Optional | `pip install` |
 |-----------|-------------|
 | **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
-| Speaker diarization with Diart | `diart` |
-| Original Whisper backend | `whisper` |
-| Improved timestamps backend | `whisper-timestamped` |
-| Apple Silicon optimization backend | `mlx-whisper` |
+| **Apple Silicon optimized backend** | `mlx-whisper` |
+| *[Not recommended]* Speaker diarization with Diart | `diart` |
+| *[Not recommended]* Original Whisper backend | `whisper` |
+| *[Not recommended]* Improved timestamps backend | `whisper-timestamped` |
 | OpenAI API backend | `openai` |

 See **Parameters & Configuration** below on how to use them.

@@ -128,7 +128,7 @@ async def websocket_endpoint(websocket: WebSocket):
         await audio_processor.process_audio(message)
 ```

-**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
+**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`

 ## Parameters & Configuration

@@ -138,6 +138,7 @@ An important list of parameters can be changed. But what *should* you change?
 - the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
 - the `--backend`? You can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
 - `--warmup-file`, if you have one
+- `--task translate`, to translate to English
 - `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
 - `--diarization`, if you want to use it.

@@ -159,14 +160,9 @@ The rest I don't recommend. But below are your options.
 | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |

-| WhisperStreaming backend options | Description | Default |
-|-----------|-------------|---------|
-| `--confidence-validation` | Use confidence scores for faster validation | `False` |
-| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |

 | SimulStreaming backend options | Description | Default |
 |-----------|-------------|---------|
+| `--disable-fast-encoder` | Disable the Faster Whisper or MLX Whisper encoder backends (if installed). Inference can be slower, but this helps when GPU memory is limited | `False` |
 | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
 | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
 | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |

@@ -180,6 +176,12 @@ The rest I don't recommend. But below are your options.
 | `--model-path` | Direct path to a .pt model file. Downloaded if not found | `./base.pt` |
 | `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set to the expected number of concurrent users) | `1` |

+| WhisperStreaming backend options | Description | Default |
+|-----------|-------------|---------|
+| `--confidence-validation` | Use confidence scores for faster validation | `False` |
+| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |

 | Diarization options | Description | Default |
 |-----------|-------------|---------|
 | `--diarization` | Enable speaker identification | `False` |
258
ReadmeJP.md
Normal file
@@ -0,0 +1,258 @@
<h1 align="center">WhisperLiveKit</h1>

<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>

<p align="center"><b>Real-time, fully local speech-to-text with speaker identification</b></p>

<p align="center">
    <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
    <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
    <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
    <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>

Real-time speech transcription directly in your browser, with a ready-to-use backend+server and a simple frontend. ✨

#### Powered by leading research:

- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - ultra-low-latency transcription using the AlignAtt policy
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - low-latency transcription using the LocalAgreement policy
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - enterprise-grade voice activity detection

> **Why not simply run a Whisper model on each audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts words mid-syllable, and yields poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.

### Architecture

<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />

*The backend supports multiple concurrent users. Voice activity detection reduces overhead when no speech is detected.*

### Installation & Quick Start

```bash
pip install whisperlivekit
```

> **FFmpeg is required** and must be installed before using WhisperLiveKit.
>
> | OS | How to install |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | Download the .exe from https://ffmpeg.org/download.html and add it to PATH |

#### Quick start
1. **Start the transcription server:**
```bash
whisperlivekit-server --model base --language en
```

2. **Open your browser** at `http://localhost:8000`. Start speaking and your words appear in real time!

> - For the list of all available languages, see [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py).
> - For HTTPS requirements, see the SSL configuration options in the **Parameters** section.

#### Optional dependencies

| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisper` |
| Improved timestamps backend | `whisper-timestamped` |
| Apple Silicon optimized backend | `mlx-whisper` |
| OpenAI API backend | `openai` |

See **Parameters & Configuration** below on how to use them.

### Usage examples

**Command-line interface**: Start the transcription server with various options:

```bash
# Use a better model than the default (small)
whisperlivekit-server --model large-v3

# Advanced configuration with diarization and a specified language
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```

**Python API integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.

```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio

transcription_engine = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    global transcription_engine
    transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
    yield

app = FastAPI(lifespan=lifespan)

async def handle_websocket_results(websocket: WebSocket, results_generator):
    async for response in results_generator:
        await websocket.send_json(response)
    await websocket.send_json({"type": "ready_to_stop"})

@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    global transcription_engine

    # Create a new AudioProcessor for each connection, passing the shared engine
    audio_processor = AudioProcessor(transcription_engine=transcription_engine)
    results_generator = await audio_processor.create_tasks()
    results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
    await websocket.accept()
    while True:
        message = await websocket.receive_bytes()
        await audio_processor.process_audio(message)
```

**Frontend implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`.

## Parameters & Configuration

An important list of parameters can be changed. But what *should* you change?
- the `--model` size. List and recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend`? You can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.

The rest I don't recommend. But below are your options.

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--no-vac` | Disable the voice activity controller | `False` |
| `--no-vad` | Disable voice activity detection | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |


| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |


| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
| `--audio-max-len` | Maximum audio buffer length (seconds) | `30.0` |
| `--audio-min-len` | Minimum audio length to process (seconds) | `0.0` |
| `--cif-ckpt-path` | Path to the CIF model for word boundary detection | `None` |
| `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that does not scroll | `None` |
| `--max-context-tokens` | Maximum number of context tokens | `None` |
| `--model-path` | Direct path to a .pt model file. Downloaded if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set to the expected number of concurrent users) | `1` |

| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--segmentation-model` | Hugging Face model ID for the Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for the Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |


> Diarization with Diart requires access to the pyannote.audio models:
> 1. [Accept the user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept the user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept the user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
> 4. Log in to HuggingFace: `huggingface-cli login`

### 🚀 Deployment guide

To deploy WhisperLiveKit in production:

1. **Server setup**: Install a production ASGI server and start it with multiple workers

```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```

2. **Frontend**: Host your customized version of the `html` and make sure the WebSocket connection points to the right endpoint

3. **Nginx configuration** (recommended for production):

```nginx
server {
    listen 80;
    server_name your-domain.com;

    location / {
        proxy_pass http://localhost:8000;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
    }
}
```

4. **HTTPS support**: For secure deployments, use "wss://" instead of "ws://" in the WebSocket URL

## 🐋 Docker

Deploy the application easily with Docker, with GPU or CPU support.

### Prerequisites
- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed

### Quick start

**With GPU acceleration (recommended):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```

**CPU only:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```

### Advanced usage

**Custom configuration:**
```bash
# Example with a custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```

### Memory requirements
- **Large models**: Make sure your Docker runtime has enough memory allocated

#### Customization

- `--build-arg` options:
  - `EXTRAS="whisper-timestamped"` - adds extras to the image's installation (no spaces). Remember to set any necessary container options!
  - `HF_PRECACHE_DIR="./.cache/"` - preloads a model cache to speed up the first start
  - `HF_TKN_FILE="./token"` - adds your Hugging Face Hub access token to download gated models

## 🔮 Use cases
Capture discussions in real time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, automatically transcribe podcasts and videos for content creation, transcribe support calls with speaker identification for customer service...
BIN architecture.png (binary file, not shown)
Before: 388 KiB | After: 368 KiB
BIN demo.png (binary file, not shown)
Before: 423 KiB | After: 449 KiB
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "whisperlivekit"
-version = "0.2.7"
-description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
+version = "0.2.8"
+description = "Real-time speech-to-text with speaker diarization using Whisper"
 readme = "README.md"
 authors = [
     { name = "Quentin Fuxa" }

@@ -18,6 +18,11 @@ classifiers = [
     "License :: OSI Approved :: MIT License",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3.14",
+    "Programming Language :: Python :: 3.15",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Topic :: Multimedia :: Sound/Audio :: Speech"
 ]

@@ -28,7 +33,8 @@ dependencies = [
     "faster-whisper",
     "uvicorn",
     "websockets",
-    "torch",
+    "torchaudio>=2.0.0",
+    "torch>=2.0.0",
     "tqdm",
     "tiktoken",
     'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
@@ -19,6 +19,15 @@ transcription_engine = None

 @asynccontextmanager
 async def lifespan(app: FastAPI):
+
+    # to remove after 0.2.8
+    if args.backend == "simulstreaming" and not args.disable_fast_encoder:
+        logger.warning(f"""
+{'='*50}
+WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable it if you encounter issues.
+{'='*50}
+""")
+
     global transcription_engine
     transcription_engine = TranscriptionEngine(
         **vars(args),
@@ -46,6 +46,7 @@ class TranscriptionEngine:
         "confidence_validation": False,
         "buffer_trimming_sec": 15,
         # simulstreaming params:
+        "disable_fast_encoder": False,
         "frame_threshold": 25,
         "beams": 1,
         "decoder_type": None,

@@ -60,7 +61,7 @@ class TranscriptionEngine:
         "diarization_backend": "sortformer",
         # diart params:
         "segmentation_model": "pyannote/segmentation-3.0",
         "embedding_model": "pyannote/embedding",
     }

     config_dict = {**defaults, **kwargs}
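The `config_dict = {**defaults, **kwargs}` line uses the standard dict-unpacking merge, where later entries win:

```python
defaults = {"frame_threshold": 25, "beams": 1, "disable_fast_encoder": False}
overrides = {"beams": 4}

config = {**defaults, **overrides}  # overrides take precedence over defaults
print(config)  # {'frame_threshold': 25, 'beams': 4, 'disable_fast_encoder': False}
```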
@@ -97,7 +98,7 @@ class TranscriptionEngine:
         simulstreaming_kwargs = {}
         for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
                      'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
-                     'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']:
+                     'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
             if hasattr(self.args, attr):
                 simulstreaming_kwargs[attr] = getattr(self.args, attr)
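The `hasattr`/`getattr` loop above harvests only the attributes that actually exist on the parsed args, so argument sets without the new `disable_fast_encoder` flag still work. The same pattern in isolation:

```python
from types import SimpleNamespace

args = SimpleNamespace(frame_threshold=25, beams=2)  # stand-in for parsed CLI args
wanted = ['frame_threshold', 'beams', 'disable_fast_encoder']
kwargs = {a: getattr(args, a) for a in wanted if hasattr(args, a)}
print(kwargs)  # {'frame_threshold': 25, 'beams': 2}
```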
@@ -161,6 +161,14 @@ def parse_args():

     # SimulStreaming-specific arguments
     simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')

+    simulstreaming_group.add_argument(
+        "--disable-fast-encoder",
+        action="store_true",
+        default=False,
+        dest="disable_fast_encoder",
+        help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
+    )
+
     simulstreaming_group.add_argument(
         "--frame-threshold",
@@ -13,15 +13,25 @@ import os
 import gc
 logger = logging.getLogger(__name__)

-import torch
-from whisperlivekit.simul_whisper.config import AlignAttConfig
-from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
-from whisperlivekit.simul_whisper.whisper import tokenizer
+try:
+    import torch
+    from whisperlivekit.simul_whisper.config import AlignAttConfig
+    from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
+    from whisperlivekit.simul_whisper.whisper import tokenizer
+except ImportError as e:
+    raise ImportError(
+        """SimulStreaming dependencies are not available.
+        Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
+
+try:
+    from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
+    HAS_MLX_WHISPER = True
+except ImportError:
+    HAS_MLX_WHISPER = False
+
+if HAS_MLX_WHISPER:
+    HAS_FASTER_WHISPER = False
+else:
+    try:
+        from faster_whisper import WhisperModel
+        HAS_FASTER_WHISPER = True
+    except ImportError:
+        HAS_FASTER_WHISPER = False

 # TOO_MANY_REPETITIONS = 3
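The tiered fallback above (MLX Whisper first, then Faster Whisper, else the stock PyTorch encoder) can be expressed with `importlib.util.find_spec`. This sketch probes stdlib modules as stand-ins for the real optional packages:

```python
import importlib.util

def pick_encoder_backend(preferred=("mlx_whisper", "faster_whisper")):
    """Return the first importable backend name, or None for the stock encoder."""
    for name in preferred:
        if importlib.util.find_spec(name) is not None:
            return name
    return None

# Stand-in probes: `json` always exists, the other name never does.
print(pick_encoder_backend(("json", "csv")))  # json
print(pick_encoder_backend(("no_such_module_xyz",)))  # None
```

`find_spec` checks importability without importing, which avoids paying the import cost of heavyweight packages just to detect them.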
@@ -51,7 +61,10 @@ class SimulStreamingOnlineProcessor:
         model = self.asr.get_new_model_instance()
         self.model = PaddedAlignAttWhisper(
             cfg=self.asr.cfg,
-            loaded_model=model)
+            loaded_model=model,
+            mlx_encoder=self.asr.mlx_encoder,
+            fw_encoder=self.asr.fw_encoder,
+        )

     def insert_silence(self, silence_duration, offset):
         """
@@ -231,7 +244,8 @@ class SimulStreamingASR():
         self.max_context_tokens = kwargs.get('max_context_tokens', None)
         self.warmup_file = kwargs.get('warmup_file', None)
         self.preload_model_count = kwargs.get('preload_model_count', 1)

+        self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
+        self.fast_encoder = False
         if model_dir is not None:
             self.model_path = model_dir
         elif modelsize is not None:
@@ -276,15 +290,44 @@ class SimulStreamingASR():

         self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
         self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
-        self.models = [self.load_model() for i in range(self.preload_model_count)]
+
+        self.mlx_encoder, self.fw_encoder = None, None
+        if not self.disable_fast_encoder:
+            if HAS_MLX_WHISPER:
+                print('Simulstreaming will use MLX whisper for a faster encoder.')
+                mlx_model_name = mlx_model_mapping[self.model_name]
+                self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
+                self.fast_encoder = True
+            elif HAS_FASTER_WHISPER:
+                print('Simulstreaming will use Faster Whisper for the encoder.')
+                self.fw_encoder = WhisperModel(
+                    self.model_name,
+                    device='auto',
+                    compute_type='auto',
+                )
+                self.fast_encoder = True
+
+        self.models = [self.load_model() for i in range(self.preload_model_count)]


     def load_model(self):
-        whisper_model = load_model(name=self.model_name, download_root=self.model_path)
-        warmup_audio = load_file(self.warmup_file)
-        whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
+        whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
+        warmup_audio = load_file(self.warmup_file)
+        if warmup_audio is not None:
+            warmup_audio = torch.from_numpy(warmup_audio).float()
+            if self.fast_encoder:
+                temp_model = PaddedAlignAttWhisper(
+                    cfg=self.cfg,
+                    loaded_model=whisper_model,
+                    mlx_encoder=self.mlx_encoder,
+                    fw_encoder=self.fw_encoder,
+                )
+                temp_model.warmup(warmup_audio)
+                temp_model.remove_hooks()
+            else:
+                # For the standard encoder, use the original transcribe warmup
+                warmup_audio = load_file(self.warmup_file)
+                whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
         return whisper_model

     def get_new_model_instance(self):
72
whisperlivekit/simul_whisper/mlx_encoder.py
Normal file
@@ -0,0 +1,72 @@
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten

from mlx_whisper import whisper

mlx_model_mapping = {
    "tiny.en": "mlx-community/whisper-tiny.en-mlx",
    "tiny": "mlx-community/whisper-tiny-mlx",
    "base.en": "mlx-community/whisper-base.en-mlx",
    "base": "mlx-community/whisper-base-mlx",
    "small.en": "mlx-community/whisper-small.en-mlx",
    "small": "mlx-community/whisper-small-mlx",
    "medium.en": "mlx-community/whisper-medium.en-mlx",
    "medium": "mlx-community/whisper-medium-mlx",
    "large-v1": "mlx-community/whisper-large-v1-mlx",
    "large-v2": "mlx-community/whisper-large-v2-mlx",
    "large-v3": "mlx-community/whisper-large-v3-mlx",
    "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
    "large": "mlx-community/whisper-large-mlx",
}

def load_mlx_encoder(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(str(model_path / "config.json"), "r") as f:
        config = json.loads(f.read())
        config.pop("model_type", None)
        quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    wf = model_path / "weights.safetensors"
    if not wf.exists():
        wf = model_path / "weights.npz"
    weights = mx.load(str(wf))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
        nn.quantize(model, **quantization, class_predicate=class_predicate)

    weights = tree_unflatten(list(weights.items()))

    # We only want to load the encoder weights here.
    # Size examples for tiny.en:
    #   Decoder weights: 59110771 bytes
    #   Encoder weights: 15268874 bytes
    encoder_weights = {'encoder': weights['encoder']}
    del weights

    model.update(encoder_weights)
    mx.eval(model.parameters())
    return model
@@ -14,7 +14,7 @@ from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens,
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os

from time import time
from .token_buffer import TokenBuffer

import numpy as np
@@ -23,8 +23,22 @@ from .generation_progress import *
DEC_PAD = 50257
logger = logging.getLogger(__name__)

import sys
import wave

try:
    from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
    from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
    HAS_MLX_WHISPER = True
except ImportError:
    HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
    HAS_FASTER_WHISPER = False
else:
    try:
        from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
        from faster_whisper.feature_extractor import FeatureExtractor
        HAS_FASTER_WHISPER = True
    except ImportError:
        HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
@@ -33,7 +47,13 @@ import wave
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
    def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
    def __init__(
        self,
        cfg: AlignAttConfig,
        loaded_model=None,
        mlx_encoder=None,
        fw_encoder=None,
    ) -> None:
        self.log_segments = 0
        model_name = os.path.basename(cfg.model_path).replace(".pt", "")
        model_path = os.path.dirname(os.path.abspath(cfg.model_path))
@@ -42,6 +62,11 @@ class PaddedAlignAttWhisper:
        else:
            self.model = load_model(name=model_name, download_root=model_path)

        self.mlx_encoder = mlx_encoder
        self.fw_encoder = fw_encoder
        if fw_encoder:
            self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)

        logger.info(f"Model dimensions: {self.model.dims}")

        self.decode_options = DecodingOptions(
@@ -151,6 +176,15 @@ class PaddedAlignAttWhisper:
        for hook in self.l_hooks:
            hook.remove()

    def warmup(self, audio):
        try:
            self.insert_audio(audio)
            self.infer(is_last=True)
            self.refresh_segment(complete=True)
            logger.info("Model warmed up successfully")
        except Exception as e:
            logger.exception(f"Model warmup failed: {e}")

    def create_tokenizer(self, language=None):
        self.tokenizer = tokenizer.get_tokenizer(
            multilingual=self.tokenizer_is_multilingual,
@@ -359,20 +393,36 @@ class PaddedAlignAttWhisper:
        else:
            input_segments = self.segments[0]

        # mel + padding to 30s
        mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
                                         device=self.model.device).unsqueeze(0)
        # trim to 3000
        mel = pad_or_trim(mel_padded, N_FRAMES)

        # the len of actual audio
        content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)

        # encode
        encoder_feature = self.model.encoder(mel)

        # NEW: we can use a different encoder before running standard Whisper cross-attention with the hooks on the decoder
        beg_encode = time()
        if self.mlx_encoder:
            mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
            mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
            mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
            encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
            content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
            device = 'cpu'
        elif self.fw_encoder:
            audio_length_seconds = len(input_segments) / 16000
            content_mel_len = int(audio_length_seconds * 100)//2
            mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
            mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
            encoder_feature_ctranslate = self.fw_encoder.encode(mel)
            encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate))
            device = 'cpu'
        else:
            # mel + padding to 30s
            mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
                                             device=self.model.device).unsqueeze(0)
            # trim to 3000
            mel = pad_or_trim(mel_padded, N_FRAMES)
            # the len of actual audio
            content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
            encoder_feature = self.model.encoder(mel)
            device = mel.device
        end_encode = time()
        # print('Encoder duration:', end_encode-beg_encode)

        # logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
        # if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
        #     logger.debug("mel ")
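The faster-whisper branch above derives `content_mel_len` from the audio length alone: 100 mel frames per second (10 ms hop) halved by the encoder's stride-2 convolution. A minimal check of that arithmetic, with constants assumed from the standard Whisper frontend:

```python
def content_frames(num_samples, sample_rate=16000, frames_per_second=100):
    # mirrors `content_mel_len = int(audio_length_seconds * 100) // 2`
    audio_length_seconds = num_samples / sample_rate
    return int(audio_length_seconds * frames_per_second) // 2

half_second = content_frames(8000)   # 0.5 s of 16 kHz audio -> 25 encoder positions
two_seconds = content_frames(32000)  # 2 s -> 100 encoder positions
```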
@@ -397,7 +447,7 @@ class PaddedAlignAttWhisper:
        ####################### Decoding loop
        logger.info("Decoding loop starts\n")

        sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
        sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
        completed = False

        attn_of_alignment_heads = None
@@ -105,6 +105,7 @@ def load_model(
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
    decoder_only=False
) -> Whisper:
    """
    Load a Whisper ASR model
@@ -151,7 +152,14 @@ def load_model(
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model = Whisper(dims, decoder_only=decoder_only)

    if decoder_only:
        checkpoint["model_state_dict"] = {
            k: v for k, v in checkpoint["model_state_dict"].items()
            if 'encoder' not in k
        }

    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
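The `decoder_only` path above drops every checkpoint entry whose key mentions the encoder before `load_state_dict`. The filter can be sketched in isolation on a plain dict; the key names here are illustrative, not the real checkpoint layout:

```python
def drop_encoder_weights(state_dict):
    # keep only parameters that do not belong to the audio encoder,
    # mirroring the `if 'encoder' not in k` filter above
    return {k: v for k, v in state_dict.items() if 'encoder' not in k}

checkpoint = {
    "encoder.conv1.weight": [0.1],
    "encoder.blocks.0.mlp.0.weight": [0.2],
    "decoder.token_embedding.weight": [0.3],
    "decoder.blocks.0.attn.query.weight": [0.4],
}
decoder_state = drop_encoder_weights(checkpoint)
```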
@@ -253,16 +253,18 @@ class TextDecoder(nn.Module):


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
    def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )

        if not decoder_only:
            self.encoder = AudioEncoder(
                self.dims.n_mels,
                self.dims.n_audio_ctx,
                self.dims.n_audio_state,
                self.dims.n_audio_head,
                self.dims.n_audio_layer,
            )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
@@ -1,60 +0,0 @@
# gemma_translate.py
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "google/gemma-3-270m-it"

def build_prompt(tokenizer, text, target_lang, source_lang=None):
    # Use the model's chat template for best results
    if source_lang:
        user_msg = (
            f"Translate the following {source_lang} text into {target_lang}.\n"
            f"Return only the translation.\n\n"
            f"Text:\n{text}"
        )
    else:
        user_msg = (
            f"Translate the following text into {target_lang}.\n"
            f"Return only the translation.\n\n"
            f"Text:\n{text}"
        )
    chat = [{"role": "user", "content": user_msg}]
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

def translate(text, target_lang, source_lang=None, max_new_tokens=256, temperature=0.2, top_p=0.95):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )

    prompt = build_prompt(tokenizer, text, target_lang, source_lang)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0.0,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Slice off the prompt to keep only the assistant answer
    generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    out = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return out

if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Translate with google/gemma-3-270m-it")
    ap.add_argument("--text", required=True, help="Text to translate")
    ap.add_argument("--to", dest="target_lang", required=True, help="Target language (e.g., French, Spanish)")
    ap.add_argument("--from", dest="source_lang", default=None, help="Source language (optional)")
    ap.add_argument("--temp", type=float, default=0.2, help="Sampling temperature (0 = deterministic-ish)")
    ap.add_argument("--max-new", type=int, default=256, help="Max new tokens")
    args = ap.parse_args()

    print(translate(args.text, args.target_lang, args.source_lang, max_new_tokens=args.max_new, temperature=args.temp))
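The deleted `build_prompt` varies its instruction only by whether a source language is known. The message construction alone (before any chat template is applied) can be exercised without loading the model; `build_user_msg` is an illustrative extraction, not a function from the repo:

```python
def build_user_msg(text, target_lang, source_lang=None):
    # same wording as the deleted build_prompt, minus the tokenizer
    source = f"the following {source_lang} text" if source_lang else "the following text"
    return (
        f"Translate {source} into {target_lang}.\n"
        "Return only the translation.\n\n"
        f"Text:\n{text}"
    )

msg = build_user_msg("Bonjour", "English", source_lang="French")
```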
@@ -1,121 +0,0 @@
# nllb_translate.py
import argparse
from pathlib import Path
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_ID = "facebook/nllb-200-distilled-600M"

# Common language shortcuts → NLLB codes (extend as needed)
LANG_MAP = {
    "english": "eng_Latn",
    "en": "eng_Latn",
    "french": "fra_Latn",
    "fr": "fra_Latn",
    "spanish": "spa_Latn",
    "es": "spa_Latn",
    "german": "deu_Latn",
    "de": "deu_Latn",
    "italian": "ita_Latn",
    "it": "ita_Latn",
    "portuguese": "por_Latn",
    "pt": "por_Latn",
    "arabic": "arb_Arab",
    "ar": "arb_Arab",
    "russian": "rus_Cyrl",
    "ru": "rus_Cyrl",
    "turkish": "tur_Latn",
    "tr": "tur_Latn",
    "chinese": "zho_Hans",
    "zh": "zho_Hans",  # Simplified
    "zh-cn": "zho_Hans",
    "zh-hans": "zho_Hans",
    "zh-hant": "zho_Hant",  # Traditional
    "japanese": "jpn_Jpan",
    "ja": "jpn_Jpan",
    "korean": "kor_Hang",
    "ko": "kor_Hang",
    "dutch": "nld_Latn",
    "nl": "nld_Latn",
    "polish": "pol_Latn",
    "pl": "pol_Latn",
    "swedish": "swe_Latn",
    "sv": "swe_Latn",
    "norwegian": "nob_Latn",
    "no": "nob_Latn",
    "danish": "dan_Latn",
    "da": "dan_Latn",
    "finnish": "fin_Latn",
    "fi": "fin_Latn",
    "catalan": "cat_Latn",
    "ca": "cat_Latn",
    "hindi": "hin_Deva",
    "hi": "hin_Deva",
    "vietnamese": "vie_Latn",
    "vi": "vie_Latn",
    "indonesian": "ind_Latn",
    "id": "ind_Latn",
    "thai": "tha_Thai",
    "th": "tha_Thai",
}

def norm_lang(code: str) -> str:
    c = code.strip().lower()
    return LANG_MAP.get(c, code)
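`norm_lang` normalizes case and whitespace and falls back to the raw input when the shortcut is unknown, so full NLLB codes pass through unchanged. With an excerpt of the table:

```python
LANG_MAP = {"en": "eng_Latn", "fr": "fra_Latn", "japanese": "jpn_Jpan"}

def norm_lang(code):
    return LANG_MAP.get(code.strip().lower(), code)

a = norm_lang(" EN ")      # shortcut, normalized
b = norm_lang("Japanese")  # full name, case-folded
c = norm_lang("eng_Latn")  # already an NLLB code, passed through
```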
def translate_texts(texts: List[str], src_code: str, tgt_code: str,
                    max_new_tokens=512, device=None, dtype=None) -> List[str]:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, src_lang=src_code)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype if dtype is not None else (torch.float16 if torch.cuda.is_available() else torch.float32),
        device_map="auto" if torch.cuda.is_available() else None,
    )
    if device:
        model.to(device)

    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    if device or torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    forced_bos = tokenizer.convert_tokens_to_ids(tgt_code)
    with torch.no_grad():
        gen = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            forced_bos_token_id=forced_bos,
        )
    outs = tokenizer.batch_decode(gen, skip_special_tokens=True)
    return [o.strip() for o in outs]

def main():
    ap = argparse.ArgumentParser(description="Translate with facebook/nllb-200-distilled-600M")
    ap.add_argument("--text", help="Inline text to translate")
    ap.add_argument("--file", help="Path to a UTF-8 text file (one example per line)")
    ap.add_argument("--src", required=True, help="Source language (e.g. fr, fra_Latn)")
    ap.add_argument("--tgt", required=True, help="Target language (e.g. en, eng_Latn)")
    ap.add_argument("--max-new", type=int, default=512, help="Max new tokens")
    args = ap.parse_args()

    src = norm_lang(args.src)
    tgt = norm_lang(args.tgt)

    batch: List[str] = []
    if args.text:
        batch.append(args.text)
    if args.file:
        lines = Path(args.file).read_text(encoding="utf-8").splitlines()
        batch.extend([ln for ln in lines if ln.strip()])

    if not batch:
        raise SystemExit("Provide --text or --file")

    results = translate_texts(batch, src, tgt, max_new_tokens=args.max_new)
    for i, (inp, out) in enumerate(zip(batch, results), 1):
        print(f"\n--- Sample {i} ---")
        print(f"SRC [{src}]: {inp}")
        print(f"TGT [{tgt}]: {out}")

if __name__ == "__main__":
    main()
@@ -1,38 +0,0 @@
import regex
from functools import lru_cache

class SentenceSegmenter:
    """
    Regex sentence splitter for Latin languages, Japanese and Chinese.
    It is based on sacrebleu TokenizerV14International(BaseTokenizer).

    Returns: a list of strings, where each string is a sentence.
    Spaces following punctuation are appended after the punctuation within the sequence.
    The total number of characters in the output is the same as in the input.
    """

    sep = 'ŽžŽžSentenceSeparatorŽžŽž'  # string that certainly won't be in src or target
    latin_terminals = '!?.'
    jap_zh_terminals = '。!?'
    terminals = latin_terminals + jap_zh_terminals

    def __init__(self):
        # end of sentence characters:
        terminals = self.terminals
        self._re = [
            # Separate out punctuation preceded by a non-digit.
            # If followed by a space-like sequence of characters, it is
            # appended to the punctuation, not to the next sequence.
            (regex.compile(r'(\P{N})(['+terminals+r'])(\p{Z}*)'), r'\1\2\3'+self.sep),
            # Separate out punctuation followed by a non-digit
            (regex.compile(r'('+terminals+r')(\P{N})'), r'\1'+self.sep+r'\2'),
            # # Separate out symbols
            # -> no, we don't tokenize but segment the punctuation
            # (regex.compile(r'(\p{S})'), r' \1 '),
        ]

    @lru_cache(maxsize=2**16)
    def __call__(self, line):
        for (_re, repl) in self._re:
            line = _re.sub(repl, line)
        return [t for t in line.split(self.sep) if t != '']
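The deleted segmenter keeps every input character: trailing spaces stay attached to the sentence terminal. A simplified sketch of that idea with the stdlib `re` module (no Unicode property classes, so the non-digit guards around terminals are omitted here):

```python
import re

SEP = "\u0001"  # sentinel assumed absent from the text
TERMINALS = "!?.。!?"

def split_sentences(line):
    # append any following whitespace to the terminal, then cut after it
    marked = re.sub(r"([" + re.escape(TERMINALS) + r"]\s*)", r"\1" + SEP, line)
    return [s for s in marked.split(SEP) if s]

parts = split_sentences("Hello there. How are you? Fine.")
```

Joining the parts reproduces the input exactly, matching the character-preservation guarantee in the docstring above.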
@@ -1,466 +0,0 @@
import sys

import ctranslate2
import sentencepiece as spm
import transformers
import argparse

def generate_words(sp, step_results):
    tokens_buffer = []

    for step_result in step_results:
        is_new_word = step_result.token.startswith("▁")

        if is_new_word and tokens_buffer:
            word = sp.decode(tokens_buffer)
            if word:
                yield word
            tokens_buffer = []

        tokens_buffer.append(step_result.token_id)

    if tokens_buffer:
        word = sp.decode(tokens_buffer)
        if word:
            yield word

from sentence_segmenter import SentenceSegmenter

class LLMTranslator:

    def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None):
        self.system_prompt = system_prompt

        print("Loading the model...", file=sys.stderr)
        self.generator = ctranslate2.Generator("ct2_EuroLLM-9B-Instruct/", device="cuda")
        self.sp = spm.SentencePieceProcessor("EuroLLM-9B-Instruct/tokenizer.model")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("EuroLLM-9B-Instruct/")
        print("...done", file=sys.stderr)

        self.max_context_length = max_context_length

        self.max_tokens_to_trim = self.max_context_length - 10
        self.len_ratio = len_ratio

        # my regex sentence segmenter
        self.segmenter = SentenceSegmenter()

        # self.max_generation_length = 512
        # self.max_prompt_length = context_length - max_generation_length

    def start_dialog(self):
        return [{'role': 'system', 'content': self.system_prompt}]

    def build_prompt(self, dialog):
        toks = self.tokenizer.apply_chat_template(dialog, tokenize=True, add_generation_prompt=False)
        if len(dialog) == 3:
            toks = toks[:-2]
        print("len toks:", len(toks), file=sys.stderr)
        # print(toks, file=sys.stderr)

        c = self.tokenizer.convert_ids_to_tokens(toks)
        # print(c, file=sys.stderr)
        return c

    def translate(self, src, tgt_forced=""):
        # src, tgt_forced = self.trim(src, tgt_forced)

        dialog = self.start_dialog()
        dialog += [{'role': 'user', 'content': src}]
        if tgt_forced != "":
            dialog += [{'role': 'assistant', 'content': tgt_forced}]

        prompt_tokens = self.build_prompt(dialog)
        if self.len_ratio is not None:
            limit_len = int(len(self.tokenizer.encode(src)) * self.len_ratio) + 10
            limit_kw = {'max_length': limit_len}
        else:
            limit_kw = {}
        step_results = self.generator.generate_tokens(
            prompt_tokens,
            **limit_kw,
            # end_token=tokenizer.eos_token,
            # sampling_temperature=0.6,
            # sampling_topk=20,
            # sampling_topp=1,
        )

        res = []
        # output_ids = []
        for step_result in step_results:
            # is_new_word = step_result.token.startswith("▁")
            # if is_new_word and output_ids:
            #     word = self.sp.decode(output_ids)
            #     print(word, end=" ", flush=True, file=sys.stderr)
            #     output_ids = []
            # output_ids.append(step_result.token_id)
            res.append(step_result)

        # if output_ids:
        #     word = self.sp.decode(output_ids)
        #     print(word, file=sys.stderr)

        return self.sp.decode([r.token_id for r in res])
        # print(res)
        # print([s.token for s in res], file=sys.stderr)
        # print([s.token==self.tokenizer.eos_token for s in res], file=sys.stderr)

class ParallelTextBuffer:
    def __init__(self, tokenizer, max_tokens, trimming="segments", init_src="", init_tgt=""):
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

        self.src_buffer = []  # list of lists
        if init_src:
            self.src_buffer.append(init_src)

        self.tgt_buffer = []  # list of strings
        if init_tgt:
            self.tgt_buffer.append(init_tgt)

        self.trimming = trimming
        if self.trimming == "sentences":
            self.segmenter = SentenceSegmenter()

    def len_src(self):
        return sum(len(t) for t in self.src_buffer) + len(self.src_buffer) - 1

    def insert(self, src, tgt):
        self.src_buffer.append(src)
        self.tgt_buffer.append(tgt)

    def insert_src_suffix(self, s):
        if self.src_buffer:
            self.src_buffer[-1][-1] += s
        else:
            self.src_buffer.append([s])

    def trim_sentences(self):
        # src_tok_lens = [len(self.tokenizer.encode(" ".join(b))) for b in self.src_buffer]
        # tgt_tok_lens = [len(self.tokenizer.encode(t)) for t in self.tgt_buffer]

        src = " ".join(" ".join(b) for b in self.src_buffer)
        tgt = "".join(self.tgt_buffer)

        src_sp_toks = self.tokenizer.encode(src)
        tgt_sp_toks = self.tokenizer.encode(tgt)

        def trim_sentence(text):
            sents = self.segmenter(text)
            print("SENTS:", len(sents), sents, file=sys.stderr)
            return "".join(sents[1:])

        while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
            nsrc = trim_sentence(src)
            ntgt = trim_sentence(tgt)
            if not nsrc or not ntgt:
                print("src or tgt is empty after trimming.", file=sys.stderr)
                print("src: ", src, file=sys.stderr)
                print("tgt: ", tgt, file=sys.stderr)
                break
            src = nsrc
            tgt = ntgt
            src_sp_toks = self.tokenizer.encode(src)
            tgt_sp_toks = self.tokenizer.encode(tgt)
        print("TRIMMED SRC:", (src,), file=sys.stderr)
        print("TRIMMED TGT:", (tgt,), file=sys.stderr)

        self.src_buffer = [src.split()]
        self.tgt_buffer = [tgt]
        return src, tgt

    def trim_segments(self):
        print("BUFFER:", file=sys.stderr)
        for s, t in zip(self.src_buffer, self.tgt_buffer):
            print("\t", s, "...", t, file=sys.stderr)  # ,self.src_buffer, self.tgt_buffer, file=sys.stderr)
        src = " ".join(" ".join(b) for b in self.src_buffer)
        tgt = "".join(self.tgt_buffer)

        src_sp_toks = self.tokenizer.encode(src)
        tgt_sp_toks = self.tokenizer.encode(tgt)

        while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
            if len(self.src_buffer) > 1 and len(self.tgt_buffer) > 1:
                self.src_buffer.pop(0)
                self.tgt_buffer.pop(0)
            else:
                break
            src = " ".join(" ".join(b) for b in self.src_buffer)
            tgt = "".join(self.tgt_buffer)

            src_sp_toks = self.tokenizer.encode(src)
            tgt_sp_toks = self.tokenizer.encode(tgt)

        print("TRIMMED SEGMENTS SRC:", (src,), file=sys.stderr)
        print("TRIMMED SEGMENTS TGT:", (tgt,), file=sys.stderr)

        return src, tgt

    def trim(self):
        if self.trimming == "sentences":
            return self.trim_sentences()
        return self.trim_segments()
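`trim_segments` above drops the oldest aligned (src, tgt) pair until the combined token count fits `max_tokens`, always keeping at least one pair. A standalone sketch, with whitespace splitting standing in for the SentencePiece tokenizer:

```python
def trim_oldest_pairs(src_buffer, tgt_buffer, max_tokens):
    # drop oldest pairs while over budget, keeping at least one pair
    def total():
        return (sum(len(s.split()) for s in src_buffer)
                + sum(len(t.split()) for t in tgt_buffer))
    while total() > max_tokens and len(src_buffer) > 1 and len(tgt_buffer) > 1:
        src_buffer.pop(0)
        tgt_buffer.pop(0)
    return src_buffer, tgt_buffer

# 8 "tokens" total, budget of 5 -> the oldest pair is dropped
src, tgt = trim_oldest_pairs(["a b c", "d e"], ["x y", "z"], max_tokens=5)
```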
class SimulLLM:

    def __init__(self, llmtrans, min_len=0, chunk=1, trimming="sentences", language="ja", init_src="", init_tgt=""):
        self.llmtranslator = llmtrans

        # self.src_buffer = init_src
        # self.confirmed_tgt = init_tgt

        self.buffer = ParallelTextBuffer(self.llmtranslator.tokenizer, self.llmtranslator.max_tokens_to_trim, trimming=trimming, init_src=init_src, init_tgt=init_tgt)

        self.last_inserted = []
        self.last_unconfirmed = ""

        self.min_len = min_len

        self.step = chunk
        self.language = language
        if language in ["ja", "zh"]:
            self.specific_space = ""
        else:
            self.specific_space = " "

    def insert(self, src):
        if isinstance(src, str):
            self.last_inserted.append(src)
        else:
            self.last_inserted += src

    def insert_suffix(self, text):
        '''
        Insert a suffix of a word after the last inserted word.
        This can happen when the word was split into multiple parts in the input, each with different timestamps.
        '''
        if self.last_inserted:
            self.last_inserted[-1] += text
        elif self.src_buffer:
            self.buffer.insert_src_suffix(text)
        else:
            # this shouldn't happen
            self.last_inserted.append(text)

    def trim_longest_common_prefix(self, a, b):
        if self.language not in ["ja", "zh"]:
            a = a.split()
            b = b.split()
        i = 0
        for i, (x, y) in enumerate(zip(a, b)):
            if x != y:
                break
        if self.language in ["ja", "zh"]:
            # print("tady160", (a, b, i), file=sys.stderr)
            return a[:i], b[i:]
        else:
            return " ".join(a[:i]), " ".join(b[i:])
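The prefix trimming above implements the LocalAgreement idea: only the prefix shared by the previous and the current hypothesis is confirmed. A word-level sketch of that comparison:

```python
def common_prefix_split(prev_words, new_words):
    """Return (confirmed prefix of new_words, unconfirmed remainder)."""
    i = 0
    while i < len(prev_words) and i < len(new_words) and prev_words[i] == new_words[i]:
        i += 1
    return new_words[:i], new_words[i:]

confirmed, unconfirmed = common_prefix_split(
    "wir sind bereit und".split(),
    "wir sind bereit jetzt los".split(),
)
```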
    def process_iter(self):
        if self.buffer.len_src() + len(self.last_inserted) < self.min_len:
            return ""

        src, forced_tgt = self.buffer.trim()  # llmtranslator.trim(" ".join(self.src_buffer), self.confirmed_tgt)
        # self.src_buffer = self.src_buffer.split()
        # src = " ".join(self.src_buffer)

        confirmed_out = ""
        run = False
        for i in range(self.step, len(self.last_inserted), self.step):
            for w in self.last_inserted[i-self.step:i]:
                src += " " + w
            run = True
            if not run: break

        print("SRC", src, file=sys.stderr)

        print("FORCED TGT", forced_tgt, file=sys.stderr)
        out = self.llmtranslator.translate(src, forced_tgt)
        print("OUT", out, file=sys.stderr)
        confirmed, unconfirmed = self.trim_longest_common_prefix(self.last_unconfirmed, out)
        self.last_unconfirmed = unconfirmed
        # print("tady", (self.confirmed_tgt, self.specific_space, confirmed), file=sys.stderr)
        if confirmed:
            # self.confirmed_tgt += self.specific_space + confirmed
            # print(confirmed_out, confirmed, file=sys.stderr)
            confirmed_out += self.specific_space + confirmed
            print("CONFIRMED NOW:", confirmed, file=sys.stderr)

        print(file=sys.stderr)
        print(file=sys.stderr)
        print("#################", file=sys.stderr)
        if run:
            self.buffer.insert(self.last_inserted, confirmed_out)
            self.last_inserted = []

        ret = confirmed_out
        print("RET:", ret, file=sys.stderr)
        return ret

    def finalize(self):
        return self.last_unconfirmed


import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input-instance', type=str, default=None, help="Filename of instances to simulate input. If not set, txt input is read from stdin.")
# parser.add_argument('--output_instance', type=str, default=None, help="Write output as instance into this file, while also writing to stdout.")
parser.add_argument('--min-chunk-size', type=int, default=1,
                    help='Minimum number of space-delimited words to process in each LocalAgreement update. The more, the higher quality, but slower.')
parser.add_argument('--min-len', type=int, default=1,
                    help='Minimum number of space-delimited words at the beginning.')
# parser.add_argument('--start_at', type=int, default=0, help='Skip first N words.')

# maybe later
# parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
# parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')

lan_to_name = {
    "de": "German",
    "ja": "Japanese",
    "zh-tr": "Chinese Traditional",
    "zh-sim": "Chinese Simplified",
    "cs": "Czech",
}
parser.add_argument('--lan', '--language', type=str, default="de",
                    help="Target language code.",
                    choices=["de", "ja", "zh-tr", "zh-sim", "cs"])

SrcLang = "English"  # always
TgtLang = "German"
default_prompt = "You are simultaneous interpreter from {SrcLang} to {TgtLang}. We are at a conference. It is important that you translate " + \
                 "only what you hear, nothing else!"
parser.add_argument('--sys_prompt', type=str, default=None,
                    help='System prompt. If None, default one is used, depending on the language. The prompt should ')

default_init = "Please, go ahead, you can start with your presentation, we are ready."

default_inits_tgt = {
    'de': "Bitte schön, Sie können mit Ihrer Präsentation beginnen, wir sind bereit.",
    'ja': "どうぞ、プレゼンテーションを始めてください。",  # Please go ahead and start your presentation.
    'zh-tr': "請繼續,您可以開始您的簡報,我們已經準備好了。",
    'zh-sim': "请吧,你可以开始发言了,我们已经准备好了。",
    'cs': "Prosím, můžete začít s prezentací, jsme připraveni.",
}
parser.add_argument('--init_prompt_src', type=str, default=None, help='Init translation with source text. It should be a complete sentence in the source language. '
                    'It can be context specific for the given input. Default is ')
parser.add_argument('--init_prompt_tgt', type=str, default=None, help='Init translation with this target. It should be example translation of init_prompt_src. '
                    ' There is default init message, depending on the language.')

parser.add_argument('--len-threshold', type=float, default=None, help='Ratio of the length of the source and generated target, in number of sentencepiece tokens. '
                    'It should reflect the target language and. If not set, no len-threshold is used.')

# how many times is target text longer than English
lan_thresholds = {
    'de': 1.3,      # 12751/9817 ... the proportion of subword tokens for ACL6060 dev de vs. en text, for EuroLLM-9B-Instruct tokenizer
    'ja': 1.34,     # 13187/9817
    'zh': 1.23,     # 12115/9817
    'zh-tr': 1.23,  # 12115/9817
    'zh-sim': 1.23, # 12115/9817
    # 'cs': I don't know  # guessed
}
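`--len-threshold` feeds the generation cap in `LLMTranslator.translate` (`limit_len = int(len(src_tokens) * len_ratio) + 10`). With the German ratio from the table above:

```python
lan_thresholds = {'de': 1.3, 'ja': 1.34, 'zh': 1.23}

def max_target_length(src_token_count, lang):
    # mirrors `limit_len = int(len(self.tokenizer.encode(src)) * self.len_ratio) + 10`
    return int(src_token_count * lan_thresholds[lang]) + 10

cap = max_target_length(100, 'de')  # 100 source tokens -> cap of 140
```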
parser.add_argument('--language-specific-len-threshold', default=False, action="store_true",
|
||||
help='Use language-specific length threshold, e.g. 1.3 for German.')
|
||||
|
||||
parser.add_argument("--max-context-length", type=int, default=4096, help="Maximum number of tokens in the model to use.")
|
||||
|
||||
parser.add_argument("--buffer_trimming", type=str, default="sentences", choices=["segments","sentences"], help="Buffer trimming strategy.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sys_prompt is None:
    TgtLang = lan_to_name[args.lan]
    sys_prompt = default_prompt.format(SrcLang=SrcLang, TgtLang=TgtLang)
else:
    sys_prompt = args.sys_prompt

if args.init_prompt_src is None:
    init_src = default_init.split()
    if args.init_prompt_tgt is None:
        init_tgt = default_inits_tgt[args.lan]
        if args.lan == "ja":
            init_src = 'Please go ahead and start your presentation.'.split()
            print("WARNING: Default init_prompt_src not set and language is Japanese. The init_src prompt changed to be more verbose.", file=sys.stderr)
    else:
        print("WARNING: init_prompt_tgt is set but init_prompt_src is not, so the default init_prompt_src is used. It may not match the given target!", file=sys.stderr)
        init_tgt = args.init_prompt_tgt
else:
    init_src = args.init_prompt_src.split()
    if args.init_prompt_tgt is None:
        print("WARNING: init_prompt_src is set but init_prompt_tgt is not, so the default init_prompt_tgt is used. It may not match the given source!", file=sys.stderr)
        init_tgt = default_inits_tgt[args.lan]
    else:
        init_tgt = args.init_prompt_tgt
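The role of these two prompts can be sketched as one-shot priming for a chat-style LLM (hypothetical message structure; SimulLLM's actual prompt format may differ): the source/target pair is replayed as an already-completed exchange so the model starts out in translation mode.

```python
def build_init_messages(sys_prompt, init_src_words, init_tgt):
    # One-shot priming: the example source sentence and its reference
    # translation are replayed as a finished user/assistant exchange.
    return [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": " ".join(init_src_words)},
        {"role": "assistant", "content": init_tgt},
    ]

msgs = build_init_messages(
    "Translate English to German.",
    "Please go ahead and start your presentation.".split(),
    "Bitte schön, Sie können mit Ihrer Präsentation beginnen.",
)
```

This is why a mismatched src/tgt pair is worth a warning: the model imitates whatever mapping the example demonstrates.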

print("INFO: System prompt:", sys_prompt, file=sys.stderr)
print("INFO: Init prompt src:", init_src, file=sys.stderr)
print("INFO: Init prompt tgt:", init_tgt, file=sys.stderr)

if args.language_specific_len_threshold:
    if args.len_threshold is not None:
        print("ERROR: --len-threshold is set, but --language-specific-len-threshold is also set. Only one can be used.", file=sys.stderr)
        sys.exit(1)
    else:
        len_threshold = lan_thresholds[args.lan]
else:
    len_threshold = args.len_threshold

llmtrans = LLMTranslator(system_prompt=sys_prompt, max_context_length=args.max_context_length, len_ratio=len_threshold)
lan = args.lan if not args.lan.startswith("zh") else "zh"
simul = SimulLLM(llmtrans, language=lan, min_len=args.min_len, chunk=args.min_chunk_size,
                 init_src=init_src, init_tgt=init_tgt, trimming=args.buffer_trimming)

# two input options
if args.input_instance is not None:
    print("INFO: Reading input from file", args.input_instance, file=sys.stderr)
    import json
    with open(args.input_instance, "r") as f:
        instance = json.load(f)

    asr_source = instance["prediction"]
    timestamps = instance["delays"]
    elapsed = instance["elapsed"]

    yield_ts_words = zip(timestamps, timestamps, elapsed, asr_source.split())
else:
    print("INFO: Reading stdin in txt format", file=sys.stderr)
    def yield_input():
        for line in sys.stdin:
            line = line.strip()
            ts, beg, end, *_ = line.split()
            text = line[len(ts)+len(beg)+len(end)+3:]
            ts = float(ts)
            # in rare cases, the first word is a suffix of the previous word that was split into multiple parts
            if text[0] != " ":
                first, *words = text.split()
                yield (ts, beg, end, " "+first)  # mark the first word with " " so it can later be detected and inserted as a suffix
            else:
                words = text.split()
            for w in words:
                yield (ts, beg, end, w)
    yield_ts_words = yield_input()
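The offset arithmetic above is easy to misread, so here is a worked example of the suffix convention (assuming the upstream writer emits each word with its own leading space, so a missing space marks a word fragment glued to the previous word):

```python
def parse_line(line):
    # "ts beg end" plus the three separating spaces; what remains is the token
    # text, which keeps its own leading space when it is a normal word
    ts, beg, end, *_ = line.split()
    text = line[len(ts) + len(beg) + len(end) + 3:]
    return float(ts), text

# normal word: double space after the fields, so text starts with " "
_, text = parse_line("1.20 0.00 0.80  Hello")
assert text == " Hello"

# glued fragment: single space, so text does not start with " " and is a suffix
_, text = parse_line("1.20 0.00 0.80 tion")
assert text == "tion"
```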

#i = 0
for t,b,e,w in yield_ts_words:
    if w.startswith(" "):  # it is a suffix of the previous word
        w = w[1:]
        simul.insert_suffix(w)
        continue
    simul.insert(w)
    out = simul.process_iter()
    if out:
        print(t,b,e,out,flush=True)
#    if i > 50:
#        break
#    i += 1
out = simul.finalize()
print(t,b,e,out,flush=True)

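The loop above can be exercised against a toy stand-in for SimulLLM's incremental interface (hypothetical class; the real SimulLLM translates rather than echoes, but the insert/process_iter/finalize contract is the same shape):

```python
class ToySimul:
    # Stand-in for SimulLLM: words go in one at a time, and output is
    # emitted only once a chunk of `chunk` words has accumulated.
    def __init__(self, chunk=3):
        self.buf, self.chunk = [], chunk

    def insert(self, w):
        self.buf.append(w)

    def insert_suffix(self, s):
        if self.buf:
            self.buf[-1] += s  # glue a fragment onto the previous word

    def process_iter(self):
        if len(self.buf) >= self.chunk:
            out, self.buf = " ".join(self.buf), []
            return out
        return ""

    def finalize(self):
        out, self.buf = " ".join(self.buf), []
        return out

s = ToySimul()
outs = []
for w in ["hello", "wor", " ld", "again", "now"]:
    if w.startswith(" "):
        s.insert_suffix(w[1:])
        continue
    s.insert(w)
    o = s.process_iter()
    if o:
        outs.append(o)
outs.append(s.finalize())
print(outs)  # ['hello world again', 'now']
```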
@@ -31,21 +31,21 @@ def load_file(warmup_file=None, timeout=5):
             logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
         except (urllib.error.URLError, socket.timeout) as e:
             logger.warning(f"Download failed: {e}. Proceeding without warmup.")
-            return False
+            return None
         finally:
             socket.setdefaulttimeout(original_timeout)
     elif not warmup_file:
-        return False
+        return None

     if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
         logger.warning(f"Warmup file {warmup_file} invalid or missing.")
-        return False
+        return None

     try:
         audio, sr = librosa.load(warmup_file, sr=16000)
     except Exception as e:
         logger.warning(f"Failed to load audio file: {e}")
-        return False
+        return None
     return audio

 def warmup_asr(asr, warmup_file=None, timeout=5):

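The switch from `False` to `None` matters because the success value is an audio array, and arrays have awkward truthiness; a sketch of the fixed contract (toy loader, a plain list standing in for the librosa array):

```python
def load_audio(path=None):
    # Return the audio samples, or None on any failure. Returning None rather
    # than False lets callers write an unambiguous `audio is None` check.
    if not path:
        return None
    return [0.0] * 16000  # stand-in for librosa.load(path, sr=16000)[0]

audio = load_audio(None)
if audio is None:  # explicit sentinel; `if not audio` would also trip on empty audio
    print("no warmup audio")
```

With numpy arrays the distinction is even sharper: truth-testing a multi-element array raises, so the identity check against `None` is the only safe idiom.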
@@ -184,7 +184,7 @@ body {

 .settings {
     display: flex;
-    flex-direction: column;
+    flex-wrap: wrap;
     align-items: flex-start;
     gap: 12px;
 }
@@ -198,23 +198,27 @@ body {

 #chunkSelector,
 #websocketInput,
-#themeSelector {
+#themeSelector,
+#microphoneSelect {
     font-size: 16px;
     padding: 5px 8px;
     border-radius: 8px;
     border: 1px solid var(--border);
     background-color: var(--button-bg);
     color: var(--text);
-    max-height: 34px;
+    max-height: 30px;
 }

-#websocketInput {
-    width: 220px;
+#microphoneSelect {
+    width: 100%;
+    max-width: 190px;
+    min-width: 120px;
 }

 #chunkSelector:focus,
 #websocketInput:focus,
-#themeSelector:focus {
+#themeSelector:focus,
+#microphoneSelect:focus {
     outline: none;
     border-color: #007bff;
     box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
@@ -247,9 +251,9 @@ label {
 }

 .theme-selector-container {
-    position: absolute;
-    top: 20px;
-    right: 20px;
+    display: flex;
+    align-items: center;
+    margin-top: 17px;
 }

 .segmented label {
@@ -400,3 +404,57 @@ label {
     font-size: 14px;
     margin-bottom: 0px;
 }
+
+/* for smaller screens */
+@media (max-width: 768px) {
+    .settings-container {
+        flex-direction: column;
+        gap: 10px;
+    }
+
+    .settings {
+        justify-content: center;
+        gap: 8px;
+    }
+
+    .field {
+        align-items: center;
+    }
+
+    #websocketInput,
+    #microphoneSelect {
+        min-width: 100px;
+        max-width: 160px;
+    }
+
+    .theme-selector-container {
+        margin-top: 10px;
+    }
+}
+
+@media (max-width: 480px) {
+    body {
+        margin: 10px;
+    }
+
+    .settings {
+        flex-direction: column;
+        align-items: center;
+        gap: 6px;
+    }
+
+    #websocketInput,
+    #microphoneSelect {
+        max-width: 140px;
+    }
+
+    .segmented label {
+        padding: 4px 8px;
+        font-size: 12px;
+    }
+
+    .segmented img {
+        width: 14px;
+        height: 14px;
+    }
+}

@@ -1,61 +1,73 @@
<!DOCTYPE html>
<html lang="en">

<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" />
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" />
</head>

<body>
<div class="settings-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
<div class="settings-container">
<button id="recordButton">
<div class="shape-container">
<div class="shape"></div>
</div>
<div class="recording-info">
<div class="wave-container">
<canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>

<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>

<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
</div>

<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>

<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>

<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>

</div>
<div class="timer">00:00</div>
</div>
</button>

<div class="settings">
<div class="field">
<label for="websocketInput">WebSocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>

</div>
</div>
</div>

<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>

<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>

<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>

<p id="status"></p>

<div id="linesTranscript"></div>

<script src="/web/live_transcription.js"></script>
<p id="status"></p>

<div id="linesTranscript"></div>

<script src="/web/live_transcription.js"></script>
</body>
</html>

</html>
@@ -18,6 +18,8 @@ let animationFrame = null;
 let waitingForStop = false;
 let lastReceivedData = null;
 let lastSignature = null;
+let availableMicrophones = [];
+let selectedMicrophoneId = null;

 waveCanvas.width = 60 * (window.devicePixelRatio || 1);
 waveCanvas.height = 30 * (window.devicePixelRatio || 1);
@@ -31,6 +33,7 @@ const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
 const linesTranscriptDiv = document.getElementById("linesTranscript");
 const timerElement = document.querySelector(".timer");
 const themeRadios = document.querySelectorAll('input[name="theme"]');
+const microphoneSelect = document.getElementById("microphoneSelect");

 function getWaveStroke() {
     const styles = getComputedStyle(document.documentElement);
@@ -82,6 +85,61 @@ if (darkMq && darkMq.addEventListener) {
     darkMq.addListener(handleOsThemeChange);
 }

+async function enumerateMicrophones() {
+    try {
+        const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+        stream.getTracks().forEach(track => track.stop());
+
+        const devices = await navigator.mediaDevices.enumerateDevices();
+        availableMicrophones = devices.filter(device => device.kind === 'audioinput');
+
+        populateMicrophoneSelect();
+        console.log(`Found ${availableMicrophones.length} microphone(s)`);
+    } catch (error) {
+        console.error('Error enumerating microphones:', error);
+        statusText.textContent = "Error accessing microphones. Please grant permission.";
+    }
+}
+
+function populateMicrophoneSelect() {
+    if (!microphoneSelect) return;
+
+    microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
+
+    availableMicrophones.forEach((device, index) => {
+        const option = document.createElement('option');
+        option.value = device.deviceId;
+        option.textContent = device.label || `Microphone ${index + 1}`;
+        microphoneSelect.appendChild(option);
+    });
+
+    const savedMicId = localStorage.getItem('selectedMicrophone');
+    if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
+        microphoneSelect.value = savedMicId;
+        selectedMicrophoneId = savedMicId;
+    }
+}
+
+function handleMicrophoneChange() {
+    selectedMicrophoneId = microphoneSelect.value || null;
+    localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
+
+    const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
+    const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
+
+    console.log(`Selected microphone: ${deviceName}`);
+    statusText.textContent = `Microphone changed to: ${deviceName}`;
+
+    if (isRecording) {
+        statusText.textContent = "Switching microphone... Please wait.";
+        stopRecording().then(() => {
+            setTimeout(() => {
+                toggleRecording();
+            }, 1000);
+        });
+    }
+}
+
 // Helpers
 function fmt1(x) {
     const n = Number(x);
@@ -377,7 +435,11 @@ async function startRecording() {
         console.log("Error acquiring wake lock.");
     }

-    const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+    const audioConstraints = selectedMicrophoneId
+        ? { audio: { deviceId: { exact: selectedMicrophoneId } } }
+        : { audio: true };
+
+    const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);

     audioContext = new (window.AudioContext || window.webkitAudioContext)();
     analyser = audioContext.createAnalyser();
@@ -516,3 +578,22 @@ function updateUI() {
 }

 recordButton.addEventListener("click", toggleRecording);
+
+if (microphoneSelect) {
+    microphoneSelect.addEventListener("change", handleMicrophoneChange);
+}
+document.addEventListener('DOMContentLoaded', async () => {
+    try {
+        await enumerateMicrophones();
+    } catch (error) {
+        console.log("Could not enumerate microphones on load:", error);
+    }
+});
+navigator.mediaDevices.addEventListener('devicechange', async () => {
+    console.log('Device change detected, re-enumerating microphones');
+    try {
+        await enumerateMicrophones();
+    } catch (error) {
+        console.log("Error re-enumerating microphones:", error);
+    }
+});

@@ -16,32 +16,25 @@ def get_web_interface_html():
 def get_inline_ui_html():
     """Returns the complete web interface HTML with all assets embedded in a single call."""
     try:
-        # Load HTML template
         with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
-            html_content = f.read()
-
-        # Load CSS and embed it
+            html_content = f.read()
         with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
             css_content = f.read()
-
-        # Load JS and embed it
         with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
             js_content = f.read()
-
-        # Load SVG files and convert to data URIs
+        # SVG files
         with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
             system_svg = f.read()
         system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"

         with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
             light_svg = f.read()
         light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"

         with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
             dark_svg = f.read()
         dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
-
-        # Replace external references with embedded content
+        # Replace external references
         html_content = html_content.replace(
             '<link rel="stylesheet" href="/web/live_transcription.css" />',
             f'<style>\n{css_content}\n</style>'
@@ -52,7 +45,7 @@ def get_inline_ui_html():
             f'<script>\n{js_content}\n</script>'
         )

-        # Replace SVG references with data URIs
+        # Replace SVG references
         html_content = html_content.replace(
             '<img src="/web/src/system_mode.svg" alt="" />',
             f'<img src="{system_data_uri}" alt="" />'

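The embedding pattern in this hunk (read an asset, base64-encode it, inline it as a `data:` URI) can be isolated into a small helper; a sketch with a hypothetical function name:

```python
import base64

def svg_to_data_uri(svg_text: str) -> str:
    # Base64-encode the SVG and wrap it so <img src="..."> needs no extra request.
    b64 = base64.b64encode(svg_text.encode('utf-8')).decode('utf-8')
    return f"data:image/svg+xml;base64,{b64}"

uri = svg_to_data_uri('<svg xmlns="http://www.w3.org/2000/svg"></svg>')
assert uri.startswith("data:image/svg+xml;base64,")
```

Inlining every asset this way is what lets the server ship the whole UI from a single HTML response.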