Compare commits
91 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12a69205ed | ||
|
|
1f684cdd97 | ||
|
|
290470dd60 | ||
|
|
425ac7b51d | ||
|
|
0382cfbeba | ||
|
|
9b1e061b32 | ||
|
|
b4abc158b9 | ||
|
|
5832d7433d | ||
|
|
3736458503 | ||
|
|
374618e050 | ||
|
|
543972ef38 | ||
|
|
971f8473eb | ||
|
|
8434ef5efc | ||
|
|
73f36cc0ef | ||
|
|
a7db39d999 | ||
|
|
a153e11fe0 | ||
|
|
ca6f9246cc | ||
|
|
d080d675a8 | ||
|
|
40bff38933 | ||
|
|
2fe3ca0188 | ||
|
|
545ea15c9a | ||
|
|
8cbaeecc75 | ||
|
|
70e854b346 | ||
|
|
d55490cd27 | ||
|
|
1fa9e1f656 | ||
|
|
994f30e1ed | ||
|
|
b22478c0b4 | ||
|
|
94c34efd90 | ||
|
|
32099b9275 | ||
|
|
9fc6654a4a | ||
|
|
d24c110d55 | ||
|
|
4dd5d8bf8a | ||
|
|
cd9a32a36b | ||
|
|
6caf3e0485 | ||
|
|
93f002cafb | ||
|
|
c5e30c2c07 | ||
|
|
1c2afb8bd2 | ||
|
|
674b20d3af | ||
|
|
a5503308c5 | ||
|
|
e61afdefa3 | ||
|
|
426d70a790 | ||
|
|
b03a212fbf | ||
|
|
1833e7c921 | ||
|
|
777ec63a71 | ||
|
|
0a6e5ae9c1 | ||
|
|
ee448a37e9 | ||
|
|
9c051052b0 | ||
|
|
4d7c487614 | ||
|
|
65025cc448 | ||
|
|
bbba1d9bb7 | ||
|
|
99dc96c644 | ||
|
|
2a27d2030a | ||
|
|
cd160caaa1 | ||
|
|
d27b5eb23e | ||
|
|
f9d704a900 | ||
|
|
2f6e00f512 | ||
|
|
5aa312e437 | ||
|
|
ebaf36a8be | ||
|
|
babe93b99a | ||
|
|
a4e9f3cab7 | ||
|
|
b06866877a | ||
|
|
967cdfebc8 | ||
|
|
3c11c60126 | ||
|
|
2963e8a757 | ||
|
|
cb2d4ea88a | ||
|
|
add7ea07ee | ||
|
|
da8726b2cb | ||
|
|
3358877054 | ||
|
|
1f7798c7c1 | ||
|
|
c7b3bb5e58 | ||
|
|
f661f21675 | ||
|
|
b6164aa59b | ||
|
|
4209d7f7c0 | ||
|
|
334b338ab0 | ||
|
|
72f33be6f2 | ||
|
|
84890b8e61 | ||
|
|
c6668adcf3 | ||
|
|
a178ed5c22 | ||
|
|
7601c74c9c | ||
|
|
fad9ee4d21 | ||
|
|
d1a9913c47 | ||
|
|
e4ca2623cb | ||
|
|
9c1bf37960 | ||
|
|
f46528471b | ||
|
|
191680940b | ||
|
|
ee02afec56 | ||
|
|
a458028de2 | ||
|
|
abd8f2c269 | ||
|
|
f3ad4e39e4 | ||
|
|
e0a5cbf0e7 | ||
|
|
953697cd86 |
3
.gitignore
vendored
@@ -137,4 +137,5 @@ run_*.sh
|
||||
test_*.py
|
||||
launch.json
|
||||
.DS_Store
|
||||
test/*
|
||||
test/*
|
||||
nllb-200-distilled-600M-ctranslate2/*
|
||||
25
DEV_NOTES.md
@@ -18,8 +18,29 @@ Decoder weights: 59110771 bytes
|
||||
Encoder weights: 15268874 bytes
|
||||
|
||||
|
||||
# 2. Translation: Faster model for each system
|
||||
|
||||
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||
## Benchmark Results
|
||||
|
||||
Testing on MacBook M3 with NLLB-200-distilled-600M model:
|
||||
|
||||
### Standard Transformers vs CTranslate2
|
||||
|
||||
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|
||||
|-----------|-------------------------|---------------------------|---------|
|
||||
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
|
||||
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
|
||||
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
|
||||
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
|
||||
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
|
||||
|
||||
**Results:**
|
||||
- Total Standard time: 4.1068s
|
||||
- Total CTranslate2 time: 8.5476s
|
||||
- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation.
|
||||
|
||||
|
||||
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||
|
||||
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
||||
|
||||
@@ -67,4 +88,4 @@ ELSE:
|
||||
AS_2 ← B
|
||||
|
||||
to finish
|
||||
```
|
||||
```
|
||||
|
||||
85
README.md
@@ -18,8 +18,9 @@ Real-time speech transcription directly to your browser, with a ready-to-use bac
|
||||
|
||||
#### Powered by Leading Research:
|
||||
|
||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription with AlignAtt policy
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription with LocalAgreement policy
|
||||
- [SimulStreaming](https://github.com/ufalSimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
||||
@@ -39,14 +40,7 @@ Real-time speech transcription directly to your browser, with a ready-to-use bac
|
||||
```bash
|
||||
pip install whisperlivekit
|
||||
```
|
||||
|
||||
> **FFmpeg is required** and must be installed before using WhisperLiveKit
|
||||
>
|
||||
> | OS | How to install |
|
||||
> |-----------|-------------|
|
||||
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
||||
> | MacOS | `brew install ffmpeg` |
|
||||
> | Windows | Download .exe from https://ffmpeg.org/download.html and add to PATH |
|
||||
> You can also clone the repo and `pip install -e .` for the latest version.
|
||||
|
||||
#### Quick Start
|
||||
1. **Start the transcription server:**
|
||||
@@ -60,7 +54,15 @@ pip install whisperlivekit
|
||||
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||
|
||||
|
||||
#### Use it to capture audio from web pages.
|
||||
|
||||
Go to `chrome-extension` for instructions.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
#### Optional Dependencies
|
||||
|
||||
@@ -68,6 +70,7 @@ pip install whisperlivekit
|
||||
|-----------|-------------|
|
||||
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| **Apple Silicon optimized backend** | `mlx-whisper` |
|
||||
| **NLLB Translation** | `huggingface_hub` & `transformers` |
|
||||
| *[Not recommanded]* Speaker diarization with Diart | `diart` |
|
||||
| *[Not recommanded]* Original Whisper backend | `whisper` |
|
||||
| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
|
||||
@@ -82,11 +85,11 @@ See **Parameters & Configuration** below on how to use them.
|
||||
**Command-line Interface**: Start the transcription server with various options:
|
||||
|
||||
```bash
|
||||
# Use better model than default (small)
|
||||
whisperlivekit-server --model large-v3
|
||||
# Large model and translate from french to danish
|
||||
whisperlivekit-server --model large-v3 --language fr --target-language da
|
||||
|
||||
# Advanced configuration with diarization and language
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
||||
# Diarization and server listening on */80
|
||||
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
|
||||
@@ -133,24 +136,16 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
## Parameters & Configuration
|
||||
|
||||
An important list of parameters can be changed. But what *should* you change?
|
||||
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
||||
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
|
||||
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
|
||||
- `--warmup-file`, if you have one
|
||||
- `--task translate`, to translate in english
|
||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
|
||||
- `--diarization`, if you want to use it.
|
||||
|
||||
The rest I don't recommend. But below are your options.
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. | `small` |
|
||||
| `--language` | Source language code or `auto` | `auto` |
|
||||
| `--task` | `transcribe` or `translate` | `transcribe` |
|
||||
| `--backend` | Processing backend | `simulstreaming` |
|
||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md) | `small` |
|
||||
| `--model-dir` | Directory containing Whisper model.bin and other files. Overrides `--model`. | `None` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, activates translation using NLLB. Ex: `fr`. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` |
|
||||
| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
|
||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
@@ -158,11 +153,25 @@ The rest I don't recommend. But below are your options.
|
||||
| `--port` | Server port | `8000` |
|
||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
|
||||
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
||||
|
||||
| Translation options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
|
||||
| `--nllb-size` | `600M` or `1.3B` | `600M` |
|
||||
|
||||
| Diarization options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
|
||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
| SimulStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
|
||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
|
||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||
@@ -174,7 +183,8 @@ The rest I don't recommend. But below are your options.
|
||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
||||
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||
|
||||
|
||||
|
||||
| WhisperStreaming backend options | Description | Default |
|
||||
@@ -182,19 +192,10 @@ The rest I don't recommend. But below are your options.
|
||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||
|
||||
| Diarization options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
|
||||
> For diarization using Diart, you need access to pyannote.audio models:
|
||||
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
||||
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
|
||||
>4. Login with HuggingFace: `huggingface-cli login`
|
||||
|
||||
> For diarization using Diart, you need to accept user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, login to HuggingFace: `huggingface-cli login`
|
||||
|
||||
### 🚀 Deployment Guide
|
||||
|
||||
|
||||
BIN
architecture.png
|
Before Width: | Height: | Size: 368 KiB After Width: | Height: | Size: 406 KiB |
@@ -1,4 +1,4 @@
|
||||
# Available model sizes:
|
||||
# Available Whisper model sizes:
|
||||
|
||||
- tiny.en (english only)
|
||||
- tiny
|
||||
@@ -58,6 +58,7 @@
|
||||
- `small`: ~2GB VRAM
|
||||
- `medium`: ~5GB VRAM
|
||||
- `large`: ~10GB VRAM
|
||||
- `large‑v3‑turbo`: ~6GB VRAM
|
||||
|
||||
**Audio Quality Impact**:
|
||||
- Clean, clear audio: smaller models may suffice
|
||||
@@ -69,4 +70,40 @@
|
||||
2. Limited resources or need speed? → `small` or smaller
|
||||
3. Good hardware and want best quality? → `large-v3`
|
||||
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
|
||||
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
||||
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
||||
|
||||
|
||||
_______________________
|
||||
|
||||
# Translation Models and Backend
|
||||
|
||||
**Language Support**: ~200 languages
|
||||
|
||||
## Distilled Model Sizes Available
|
||||
|
||||
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|
||||
|-------|------|------------|-------------|-------------|---------|
|
||||
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
|
||||
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
|
||||
|
||||
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
|
||||
|
||||
## Backend Performance
|
||||
|
||||
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|
||||
|---------|---------------|--------------|--------------|
|
||||
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
|
||||
| Transformers | Baseline | High | None |
|
||||
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
|
||||
|
||||
**Metrics**:
|
||||
- CTranslate2: 50-100+ tokens/sec
|
||||
- Transformers: 10-30 tokens/sec
|
||||
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
|
||||
|
||||
## Quick Decision Matrix
|
||||
|
||||
**Choose 600M**: Limited resources, close to 0 lag
|
||||
**Choose 1.3B**: Quality matters
|
||||
**Choose Transformers**: On Apple Silicon
|
||||
|
||||
|
||||
19
chrome-extension/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
## WhisperLiveKit Chrome Extension v0.1.1
|
||||
Capture the audio of your current tab, transcribe diarize and translate it using WhisperliveKit, in Chrome and other Chromium-based browsers.
|
||||
|
||||
> Currently, only the tab audio is captured; your microphone audio is not recorded.
|
||||
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
||||
|
||||
## Running this extension
|
||||
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
||||
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
||||
|
||||
|
||||
## Devs:
|
||||
- Impossible to capture audio from tabs if extension is a pannel, unfortunately:
|
||||
- https://issues.chromium.org/issues/40926394
|
||||
- https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
|
||||
- https://issues.chromium.org/issues/40916430
|
||||
|
||||
- To capture microphone in an extension, there are tricks: https://github.com/justinmann/sidepanel-audio-issue , https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (comments)
|
||||
9
chrome-extension/background.js
Normal file
@@ -0,0 +1,9 @@
|
||||
chrome.runtime.onInstalled.addListener((details) => {
|
||||
if (details.reason.search(/install/g) === -1) {
|
||||
return
|
||||
}
|
||||
chrome.tabs.create({
|
||||
url: chrome.runtime.getURL("welcome.html"),
|
||||
active: true
|
||||
})
|
||||
})
|
||||
BIN
chrome-extension/demo-extension.png
Normal file
|
After Width: | Height: | Size: 5.8 MiB |
BIN
chrome-extension/icons/icon128.png
Normal file
|
After Width: | Height: | Size: 5.8 KiB |
BIN
chrome-extension/icons/icon16.png
Normal file
|
After Width: | Height: | Size: 376 B |
BIN
chrome-extension/icons/icon32.png
Normal file
|
After Width: | Height: | Size: 823 B |
BIN
chrome-extension/icons/icon48.png
Normal file
|
After Width: | Height: | Size: 1.4 KiB |
23
chrome-extension/manifest.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"manifest_version": 3,
|
||||
"name": "WhisperLiveKit Tab Capture",
|
||||
"version": "1.0",
|
||||
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
|
||||
"icons": {
|
||||
"16": "icons/icon16.png",
|
||||
"32": "icons/icon32.png",
|
||||
"48": "icons/icon48.png",
|
||||
"128": "icons/icon128.png"
|
||||
},
|
||||
"action": {
|
||||
"default_title": "WhisperLiveKit Tab Capture",
|
||||
"default_popup": "live_transcription.html"
|
||||
},
|
||||
"permissions": [
|
||||
"scripting",
|
||||
"tabCapture",
|
||||
"offscreen",
|
||||
"activeTab",
|
||||
"storage"
|
||||
]
|
||||
}
|
||||
12
chrome-extension/requestPermissions.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Request Permissions</title>
|
||||
<script src="requestPermissions.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
This page exists to workaround an issue with Chrome that blocks permission
|
||||
requests from chrome extensions
|
||||
<button id="requestMicrophone">Request Microphone</button>
|
||||
</body>
|
||||
</html>
|
||||
17
chrome-extension/requestPermissions.js
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Requests user permission for microphone access.
|
||||
* @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
|
||||
*/
|
||||
async function getUserPermission() {
|
||||
console.log("Getting user permission for microphone access...");
|
||||
await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state == "granted") {
|
||||
window.close();
|
||||
}
|
||||
}
|
||||
|
||||
// Call the function to request microphone permission
|
||||
getUserPermission();
|
||||
29
chrome-extension/sidepanel.js
Normal file
@@ -0,0 +1,29 @@
|
||||
console.log("sidepanel.js");
|
||||
|
||||
async function run() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
|
||||
if (micPermission.state !== "granted") {
|
||||
chrome.tabs.create({ url: "requestPermissions.html" });
|
||||
}
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void run();
|
||||
BIN
demo.png
|
Before Width: | Height: | Size: 449 KiB After Width: | Height: | Size: 985 KiB |
264
docs/API.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# WhisperLiveKit WebSocket API Documentation
|
||||
|
||||
> !! **Note**: The new API structure described in this document is currently under deployment.
|
||||
This documentation is intended for devs who want to build custom frontends.
|
||||
|
||||
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
|
||||
|
||||
---
|
||||
|
||||
## Legacy API (Current)
|
||||
|
||||
### Message Structure
|
||||
|
||||
The current API sends complete state snapshots on each update (several time per second)
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": str,
|
||||
"status": str,
|
||||
"lines": [
|
||||
{
|
||||
"speaker": int,
|
||||
"text": str,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"translation": str | null,
|
||||
"detected_language": str
|
||||
}
|
||||
],
|
||||
"buffer_transcription": str,
|
||||
"buffer_diarization": str,
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## New API (Under Development)
|
||||
|
||||
### Philosophy
|
||||
|
||||
Principles:
|
||||
|
||||
- **Incremental Updates**: Only updates and new segments are sent
|
||||
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
|
||||
|
||||
|
||||
## Message Format
|
||||
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": "transcript_update",
|
||||
"status": "active_transcription" | "no_audio_detected",
|
||||
"segments": [
|
||||
{
|
||||
"id": number,
|
||||
"speaker": number,
|
||||
"text": string,
|
||||
"start_speaker": float,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"language": string | null,
|
||||
"translation": string,
|
||||
"words": [
|
||||
{
|
||||
"text": string,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"validated": {
|
||||
"text": boolean,
|
||||
"speaker": boolean,
|
||||
}
|
||||
}
|
||||
],
|
||||
"buffer": {
|
||||
"transcription": string,
|
||||
"diarization": string,
|
||||
"translation": string
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Other Message Types
|
||||
|
||||
#### Config Message (sent on connection)
|
||||
```json
|
||||
{
|
||||
"type": "config",
|
||||
"useAudioWorklet": true / false
|
||||
}
|
||||
```
|
||||
|
||||
#### Ready to Stop Message (sent after processing complete)
|
||||
```json
|
||||
{
|
||||
"type": "ready_to_stop"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Field Descriptions
|
||||
|
||||
### Segment Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
|
||||
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
|
||||
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
|
||||
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
|
||||
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
|
||||
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
|
||||
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
|
||||
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
|
||||
| `words` | `Array` | Array of word-level objects with timing and validation information. |
|
||||
| `buffer` | `Object` | Per-segment temporary buffers, see below |
|
||||
|
||||
### Word Object
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `text` | `string` | The word text. |
|
||||
| `start` | `number` | Start timestamp (seconds) of this word. |
|
||||
| `end` | `number` | End timestamp (seconds) of this word. |
|
||||
| `validated.text` | `boolean` | Whether the transcription text has been validated. if false, word is also in buffer: transcription |
|
||||
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. if false, word is also in buffer: diarization |
|
||||
| `validated.language` | `boolean` | Whether the language detection has been validated. if false, word is also in buffer: translation |
|
||||
|
||||
### Buffer Object (Per-Segment)
|
||||
|
||||
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and previous buffer is likely to be in the next validated text.
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
|
||||
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
|
||||
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
|
||||
|
||||
|
||||
### Metadata Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
|
||||
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
|
||||
|
||||
### Status Values
|
||||
|
||||
| Status | Description |
|
||||
|--------|-------------|
|
||||
| `active_transcription` | Normal operation, transcription is active. |
|
||||
| `no_audio_detected` | No audio has been detected yet. |
|
||||
|
||||
---
|
||||
|
||||
## Update Behavior
|
||||
|
||||
### Incremental Updates
|
||||
|
||||
The API sends **only changed or new segments**. Clients should:
|
||||
|
||||
1. Maintain a local map of segments by ID
|
||||
2. When receiving an update, merge/update segments by ID
|
||||
3. Render only the changed segments
|
||||
|
||||
### Language Detection
|
||||
|
||||
When language is detected for a segment:
|
||||
|
||||
```jsonc
|
||||
// Update 1: No language yet
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "May see", "language": null}
|
||||
]
|
||||
}
|
||||
|
||||
// Update 2: Same segment ID, language now detected
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Client behavior**: **Replace** the existing segment with the same ID.
|
||||
|
||||
### Buffer Behavior
|
||||
|
||||
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
|
||||
|
||||
#### Example: Translation with diarization and translation
|
||||
|
||||
```jsonc
|
||||
// Update 1
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": "Hello world, how are",
|
||||
"translation": "",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " you on",
|
||||
"translation": "Bonjour le monde"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
|
||||
|
||||
|
||||
// Update 2
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": " you on this",
|
||||
"translation": "Bonjour tout le monde",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " beautiful day",
|
||||
"translation": ",comment"
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER><TRANSLATION>
|
||||
```
|
||||
|
||||
### Silence Segments
|
||||
|
||||
Silence is represented with the speaker id = `-2`:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"id": 5,
|
||||
"speaker": -2,
|
||||
"text": "",
|
||||
"start": 10.5,
|
||||
"end": 12.3
|
||||
}
|
||||
```
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.8"
|
||||
version = "0.2.12"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -50,7 +50,7 @@ Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
|
||||
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom", "whisperlivekit.translation"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
|
||||
38
sync_extension.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import shutil
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def sync_extension_files():
|
||||
"""Copy core files from web directory to Chrome extension directory."""
|
||||
|
||||
web_dir = Path("whisperlivekit/web")
|
||||
extension_dir = Path("chrome-extension")
|
||||
|
||||
files_to_sync = [
|
||||
"live_transcription.html", "live_transcription.js", "live_transcription.css"
|
||||
]
|
||||
|
||||
svg_files = [
|
||||
"system_mode.svg",
|
||||
"light_mode.svg",
|
||||
"dark_mode.svg",
|
||||
"settings.svg"
|
||||
]
|
||||
|
||||
for file in files_to_sync:
|
||||
src_path = web_dir / file
|
||||
dest_path = extension_dir / file
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
|
||||
for svg_file in svg_files:
|
||||
src_path = web_dir / "src" / svg_file
|
||||
dest_path = extension_dir / "web" / "src" / svg_file
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
sync_extension_files()
|
||||
@@ -4,18 +4,41 @@ from time import time, sleep
|
||||
import math
|
||||
import logging
|
||||
import traceback
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output, format_time
|
||||
# Set up logging once
|
||||
from whisperlivekit.results_formater import format_output
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
|
||||
def cut_at(cumulative_pcm, cut_sec):
|
||||
cumulative_len = 0
|
||||
cut_sample = int(cut_sec * 16000)
|
||||
|
||||
for ind, pcm_array in enumerate(cumulative_pcm):
|
||||
if (cumulative_len + len(pcm_array)) >= cut_sample:
|
||||
cut_chunk = cut_sample - cumulative_len
|
||||
before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]])
|
||||
after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:]
|
||||
return before, after
|
||||
cumulative_len += len(pcm_array)
|
||||
return np.concatenate(cumulative_pcm), []
|
||||
|
||||
async def get_all_from_queue(queue):
|
||||
items = []
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
items.append(item)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
return items
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
Processes audio streams for transcription and diarization.
|
||||
@@ -38,89 +61,84 @@ class AudioProcessor:
|
||||
self.bytes_per_sample = 2
|
||||
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||
self.last_ffmpeg_activity = time()
|
||||
self.ffmpeg_health_check_interval = 5
|
||||
self.ffmpeg_max_idle_time = 10
|
||||
self.debug = False
|
||||
self.is_pcm_input = self.args.pcm_input
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.silence = False
|
||||
self.silence_duration = 0.0
|
||||
self.tokens = []
|
||||
self.buffer_transcription = ""
|
||||
self.buffer_diarization = ""
|
||||
self.last_validated_token = 0
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = Transcript()
|
||||
self.end_buffer = 0
|
||||
self.end_attributed_speaker = 0
|
||||
self.lock = asyncio.Lock()
|
||||
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
||||
self.beg_loop = 0.0 #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = ""
|
||||
self.last_response_content = FrontData()
|
||||
self.last_detected_speaker = None
|
||||
self.speaker_languages = {}
|
||||
self.diarization_before_transcription = False
|
||||
|
||||
if self.diarization_before_transcription:
|
||||
self.cumulative_pcm = []
|
||||
self.last_start = 0.0
|
||||
self.last_end = 0.0
|
||||
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.tokenizer = models.tokenizer
|
||||
self.vac_model = models.vac_model
|
||||
if self.args.vac:
|
||||
self.vac = FixedVADIterator(models.vac_model)
|
||||
else:
|
||||
self.vac = None
|
||||
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels
|
||||
)
|
||||
|
||||
async def handle_ffmpeg_error(error_type: str):
|
||||
logger.error(f"FFmpeg error: {error_type}")
|
||||
self._ffmpeg_error = error_type
|
||||
|
||||
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||
|
||||
self.ffmpeg_manager = None
|
||||
self.ffmpeg_reader_task = None
|
||||
self._ffmpeg_error = None
|
||||
|
||||
|
||||
if not self.is_pcm_input:
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels
|
||||
)
|
||||
async def handle_ffmpeg_error(error_type: str):
|
||||
logger.error(f"FFmpeg error: {error_type}")
|
||||
self._ffmpeg_error = error_type
|
||||
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer = bytearray()
|
||||
|
||||
# Task references
|
||||
self.transcription_task = None
|
||||
self.diarization_task = None
|
||||
self.ffmpeg_reader_task = None
|
||||
self.translation_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
|
||||
# Initialize transcription engine if enabled
|
||||
self.transcription = None
|
||||
self.translation = None
|
||||
self.diarization = None
|
||||
|
||||
if self.args.transcription:
|
||||
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||
|
||||
# Initialize diarization engine if enabled
|
||||
self.transcription = online_factory(self.args, models.asr)
|
||||
self.sep = self.transcription.asr.sep
|
||||
if self.args.diarization:
|
||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||
|
||||
if models.translation_model:
|
||||
self.translation = online_translation_factory(self.args, models.translation_model)
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
|
||||
"""Thread-safe update of transcription with new data."""
|
||||
async with self.lock:
|
||||
self.tokens.extend(new_tokens)
|
||||
self.buffer_transcription = buffer
|
||||
self.end_buffer = end_buffer
|
||||
self.sep = sep
|
||||
|
||||
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
||||
"""Thread-safe update of diarization with new data."""
|
||||
async with self.lock:
|
||||
self.end_attributed_speaker = end_attributed_speaker
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
async def add_dummy_token(self):
|
||||
"""Placeholder token when no transcription is available."""
|
||||
async with self.lock:
|
||||
current_time = time() - self.beg_loop if self.beg_loop else 0
|
||||
current_time = time() - self.beg_loop
|
||||
self.tokens.append(ASRToken(
|
||||
start=current_time, end=current_time + 1,
|
||||
text=".", speaker=-1, is_dummy=True
|
||||
@@ -141,33 +159,36 @@ class AudioProcessor:
|
||||
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
|
||||
|
||||
return {
|
||||
"tokens": self.tokens.copy(),
|
||||
"buffer_transcription": self.buffer_transcription,
|
||||
"buffer_diarization": self.buffer_diarization,
|
||||
"end_buffer": self.end_buffer,
|
||||
"end_attributed_speaker": self.end_attributed_speaker,
|
||||
"sep": self.sep,
|
||||
"remaining_time_transcription": remaining_transcription,
|
||||
"remaining_time_diarization": remaining_diarization
|
||||
}
|
||||
return State(
|
||||
tokens=self.tokens.copy(),
|
||||
last_validated_token=self.last_validated_token,
|
||||
translated_segments=self.translated_segments.copy(),
|
||||
buffer_transcription=self.buffer_transcription,
|
||||
end_buffer=self.end_buffer,
|
||||
end_attributed_speaker=self.end_attributed_speaker,
|
||||
remaining_time_transcription=remaining_transcription,
|
||||
remaining_time_diarization=remaining_diarization
|
||||
)
|
||||
|
||||
async def reset(self):
|
||||
"""Reset all state variables to initial values."""
|
||||
async with self.lock:
|
||||
self.tokens = []
|
||||
self.buffer_transcription = self.buffer_diarization = ""
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = Transcript()
|
||||
self.end_buffer = self.end_attributed_speaker = 0
|
||||
self.beg_loop = time()
|
||||
|
||||
async def ffmpeg_stdout_reader(self):
|
||||
"""Read audio data from FFmpeg stdout and process it."""
|
||||
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||
beg = time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Check if FFmpeg is running
|
||||
state = await self.ffmpeg_manager.get_state()
|
||||
if self.is_stopping:
|
||||
logger.info("Stopping ffmpeg_stdout_reader due to stopping flag.")
|
||||
break
|
||||
|
||||
state = await self.ffmpeg_manager.get_state() if self.ffmpeg_manager else FFmpegState.STOPPED
|
||||
if state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot read data")
|
||||
break
|
||||
@@ -175,100 +196,41 @@ class AudioProcessor:
|
||||
logger.info("FFmpeg is stopped")
|
||||
break
|
||||
elif state != FFmpegState.RUNNING:
|
||||
logger.warning(f"FFmpeg is in {state} state, waiting...")
|
||||
await asyncio.sleep(0.5)
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
|
||||
current_time = time()
|
||||
elapsed_time = math.floor((current_time - beg) * 10) / 10
|
||||
buffer_size = max(int(32000 * elapsed_time), 4096)
|
||||
elapsed_time = max(0.0, current_time - beg)
|
||||
buffer_size = max(int(32000 * elapsed_time), 4096) # dynamic read
|
||||
beg = current_time
|
||||
|
||||
chunk = await self.ffmpeg_manager.read_data(buffer_size)
|
||||
|
||||
if not chunk:
|
||||
if self.is_stopping:
|
||||
logger.info("FFmpeg stdout closed, stopping.")
|
||||
break
|
||||
else:
|
||||
# No data available, but not stopping - FFmpeg might be restarting
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# No data currently available
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
self.pcm_buffer.extend(chunk)
|
||||
await self.handle_pcm_data()
|
||||
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
||||
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||
logger.warning(
|
||||
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
|
||||
f"Consider using a smaller model."
|
||||
)
|
||||
|
||||
# Process audio chunk
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||
|
||||
res = None
|
||||
end_of_audio = False
|
||||
silence_buffer = None
|
||||
|
||||
if self.args.vac:
|
||||
res = self.vac(pcm_array)
|
||||
|
||||
if res is not None:
|
||||
if res.get('end', 0) > res.get('start', 0):
|
||||
end_of_audio = True
|
||||
elif self.silence: #end of silence
|
||||
self.silence = False
|
||||
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||
|
||||
if silence_buffer:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
|
||||
if not self.silence:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_array.copy())
|
||||
|
||||
self.silence_duration = 0.0
|
||||
if end_of_audio:
|
||||
self.silence = True
|
||||
self.start_silence = time()
|
||||
|
||||
# Sleep if no processing is happening
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("ffmpeg_stdout_reader cancelled.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
# Try to recover by waiting a bit
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check if we should exit
|
||||
if self.is_stopping:
|
||||
break
|
||||
|
||||
logger.info("FFmpeg stdout processing finished. Signaling downstream processors.")
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
logger.debug("Sentinel put into transcription_queue.")
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
logger.debug("Sentinel put into diarization_queue.")
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
if self.diarization:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
if self.translation:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
async def transcription_processor(self):
|
||||
"""Process audio chunks for transcription."""
|
||||
self.sep = self.online.asr.sep
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
|
||||
while True:
|
||||
@@ -278,65 +240,59 @@ class AudioProcessor:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
self.transcription_queue.task_done()
|
||||
break
|
||||
|
||||
if not self.online:
|
||||
logger.warning("Transcription processor: self.online not initialized.")
|
||||
self.transcription_queue.task_done()
|
||||
continue
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||
if type(item) is Silence:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
if self.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
if type(item) is Silence:
|
||||
logger.info(asr_processing_logs)
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||
self.transcription.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
self.transcription.new_speaker(item)
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
|
||||
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
|
||||
|
||||
|
||||
|
||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = self.online.process_iter()
|
||||
|
||||
# Get buffer information
|
||||
_buffer_transcript_obj = self.online.get_buffer()
|
||||
buffer_text = _buffer_transcript_obj.text
|
||||
_buffer_transcript = self.transcription.get_buffer()
|
||||
buffer_text = _buffer_transcript.text
|
||||
|
||||
if new_tokens:
|
||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||
if buffer_text.startswith(validated_text):
|
||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
||||
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||
|
||||
candidate_end_times = [self.end_buffer]
|
||||
|
||||
if new_tokens:
|
||||
candidate_end_times.append(new_tokens[-1].end)
|
||||
|
||||
if _buffer_transcript_obj.end is not None:
|
||||
candidate_end_times.append(_buffer_transcript_obj.end)
|
||||
if _buffer_transcript.end is not None:
|
||||
candidate_end_times.append(_buffer_transcript.end)
|
||||
|
||||
candidate_end_times.append(current_audio_processed_upto)
|
||||
|
||||
new_end_buffer = max(candidate_end_times)
|
||||
async with self.lock:
|
||||
self.tokens.extend(new_tokens)
|
||||
self.buffer_transcription = _buffer_transcript
|
||||
self.end_buffer = max(candidate_end_times)
|
||||
|
||||
await self.update_transcription(
|
||||
new_tokens, buffer_text, new_end_buffer, self.sep
|
||||
)
|
||||
if self.translation_queue:
|
||||
for token in new_tokens:
|
||||
await self.translation_queue.put(token)
|
||||
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
@@ -344,13 +300,22 @@ class AudioProcessor:
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
if self.is_stopping:
|
||||
logger.info("Transcription processor finishing due to stopping flag.")
|
||||
if self.diarization_queue:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
logger.info("Transcription processor task finished.")
|
||||
|
||||
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
"""Process audio chunks for speaker diarization."""
|
||||
buffer_diarization = ""
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
if self.diarization_before_transcription:
|
||||
self.current_speaker = 0
|
||||
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0))
|
||||
while True:
|
||||
try:
|
||||
item = await self.diarization_queue.get()
|
||||
@@ -358,30 +323,49 @@ class AudioProcessor:
|
||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
|
||||
if type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
elif type(item) is Silence:
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
|
||||
|
||||
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
async with self.lock:
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
if self.diarization_before_transcription:
|
||||
segments = diarization_obj.get_segments()
|
||||
self.cumulative_pcm.append(pcm_array)
|
||||
if segments:
|
||||
last_segment = segments[-1]
|
||||
if last_segment.speaker != self.current_speaker:
|
||||
cut_sec = last_segment.start - self.last_end
|
||||
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
|
||||
await self.transcription_queue.put(to_transcript)
|
||||
|
||||
self.current_speaker = last_segment.speaker
|
||||
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=last_segment.start))
|
||||
|
||||
cut_sec = last_segment.end - last_segment.start
|
||||
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
|
||||
await self.transcription_queue.put(to_transcript)
|
||||
self.last_start = last_segment.start
|
||||
self.last_end = last_segment.end
|
||||
else:
|
||||
cut_sec = last_segment.end - self.last_end
|
||||
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
|
||||
await self.transcription_queue.put(to_transcript)
|
||||
self.last_end = last_segment.end
|
||||
elif not self.diarization_before_transcription:
|
||||
async with self.lock:
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
@@ -391,100 +375,112 @@ class AudioProcessor:
|
||||
self.diarization_queue.task_done()
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self):
|
||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||
# And the speaker is attributed given the segments used for the translation
|
||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||
while True:
|
||||
try:
|
||||
item = await self.translation_queue.get() #block until at least 1 token
|
||||
if item is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
elif type(item) is Silence:
|
||||
self.translation.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
# get all the available tokens for translation. The more words, the more precise
|
||||
tokens_to_process = [item]
|
||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
||||
|
||||
sentinel_found = False
|
||||
for additional_token in additional_tokens:
|
||||
if additional_token is SENTINEL:
|
||||
sentinel_found = True
|
||||
break
|
||||
elif type(additional_token) is Silence:
|
||||
self.translation.insert_silence(additional_token.duration)
|
||||
continue
|
||||
else:
|
||||
tokens_to_process.append(additional_token)
|
||||
if tokens_to_process:
|
||||
self.translation.insert_tokens(tokens_to_process)
|
||||
self.translated_segments = await asyncio.to_thread(self.translation.process)
|
||||
self.translation_queue.task_done()
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
|
||||
if sentinel_found:
|
||||
logger.debug("Translation processor received sentinel in batch. Finishing.")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'token' in locals() and item is not SENTINEL:
|
||||
self.translation_queue.task_done()
|
||||
if 'additional_tokens' in locals():
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
logger.info("Translation processor task finished.")
|
||||
|
||||
async def results_formatter(self):
|
||||
"""Format processing results for output."""
|
||||
last_sent_trans = None
|
||||
last_sent_diar = None
|
||||
while True:
|
||||
try:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED and self._ffmpeg_error:
|
||||
yield {
|
||||
"status": "error",
|
||||
"error": f"FFmpeg error: {self._ffmpeg_error}",
|
||||
"lines": [],
|
||||
"buffer_transcription": "",
|
||||
"buffer_diarization": "",
|
||||
"remaining_time_transcription": 0,
|
||||
"remaining_time_diarization": 0
|
||||
}
|
||||
if self._ffmpeg_error:
|
||||
yield FrontData(status="error", error=f"FFmpeg error: {self._ffmpeg_error}")
|
||||
self._ffmpeg_error = None
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# Get current state
|
||||
|
||||
state = await self.get_current_state()
|
||||
tokens = state["tokens"]
|
||||
buffer_transcription = state["buffer_transcription"]
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
sep = state["sep"]
|
||||
|
||||
# Add dummy tokens if needed
|
||||
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
|
||||
await self.add_dummy_token()
|
||||
sleep(0.5)
|
||||
state = await self.get_current_state()
|
||||
tokens = state["tokens"]
|
||||
|
||||
# Format output
|
||||
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
|
||||
|
||||
lines, undiarized_text = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
current_time = time() - self.beg_loop if self.beg_loop else None,
|
||||
diarization = self.args.diarization,
|
||||
debug = self.debug
|
||||
current_time = time() - self.beg_loop,
|
||||
args = self.args,
|
||||
sep=self.sep
|
||||
)
|
||||
# Handle undiarized text
|
||||
if lines and lines[-1].speaker == -2:
|
||||
buffer_transcription = Transcript()
|
||||
else:
|
||||
buffer_transcription = state.buffer_transcription
|
||||
|
||||
buffer_diarization = ''
|
||||
if undiarized_text:
|
||||
combined = sep.join(undiarized_text)
|
||||
if buffer_transcription:
|
||||
combined += sep
|
||||
await self.update_diarization(end_attributed_speaker, combined)
|
||||
buffer_diarization = combined
|
||||
buffer_diarization = self.sep.join(undiarized_text)
|
||||
|
||||
async with self.lock:
|
||||
self.end_attributed_speaker = state.end_attributed_speaker
|
||||
|
||||
response_status = "active_transcription"
|
||||
final_lines_for_response = lines.copy()
|
||||
|
||||
if not tokens and not buffer_transcription and not buffer_diarization:
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
response_status = "no_audio_detected"
|
||||
final_lines_for_response = []
|
||||
elif response_status == "active_transcription" and not final_lines_for_response:
|
||||
final_lines_for_response = [{
|
||||
"speaker": 1,
|
||||
"text": "",
|
||||
"beg": format_time(state.get("end_buffer", 0)),
|
||||
"end": format_time(state.get("end_buffer", 0)),
|
||||
"diff": 0
|
||||
}]
|
||||
lines = []
|
||||
elif not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.end_buffer,
|
||||
end=state.end_buffer
|
||||
)]
|
||||
|
||||
response = {
|
||||
"status": response_status,
|
||||
"lines": final_lines_for_response,
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": state["remaining_time_transcription"],
|
||||
"remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
|
||||
}
|
||||
|
||||
current_response_signature = f"{response_status} | " + \
|
||||
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
|
||||
f" | {buffer_transcription} | {buffer_diarization}"
|
||||
|
||||
trans = state["remaining_time_transcription"]
|
||||
diar = state["remaining_time_diarization"]
|
||||
should_push = (
|
||||
current_response_signature != self.last_response_content
|
||||
or last_sent_trans is None
|
||||
or round(trans, 1) != round(last_sent_trans, 1)
|
||||
or round(diar, 1) != round(last_sent_diar, 1)
|
||||
response = FrontData(
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text.strip(),
|
||||
buffer_diarization=buffer_diarization,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
if should_push and (final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
|
||||
|
||||
should_push = (response != self.last_response_content)
|
||||
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
yield response
|
||||
self.last_response_content = current_response_signature
|
||||
last_sent_trans = trans
|
||||
last_sent_diar = diar
|
||||
self.last_response_content = response
|
||||
|
||||
# Check for termination condition
|
||||
if self.is_stopping:
|
||||
@@ -496,50 +492,50 @@ class AudioProcessor:
|
||||
|
||||
if all_processors_done:
|
||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||
final_state = await self.get_current_state()
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.1) # Avoid overwhelming the client
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in results_formatter: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5) # Back off on error
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def create_tasks(self):
|
||||
"""Create and start processing tasks."""
|
||||
self.all_tasks_for_cleanup = []
|
||||
processing_tasks_for_watchdog = []
|
||||
|
||||
success = await self.ffmpeg_manager.start()
|
||||
if not success:
|
||||
logger.error("Failed to start FFmpeg manager")
|
||||
async def error_generator():
|
||||
yield {
|
||||
"status": "error",
|
||||
"error": "FFmpeg failed to start. Please check that FFmpeg is installed.",
|
||||
"lines": [],
|
||||
"buffer_transcription": "",
|
||||
"buffer_diarization": "",
|
||||
"remaining_time_transcription": 0,
|
||||
"remaining_time_diarization": 0
|
||||
}
|
||||
return error_generator()
|
||||
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
|
||||
if not self.is_pcm_input:
|
||||
success = await self.ffmpeg_manager.start()
|
||||
if not success:
|
||||
logger.error("Failed to start FFmpeg manager")
|
||||
async def error_generator():
|
||||
yield FrontData(
|
||||
status="error",
|
||||
error="FFmpeg failed to start. Please check that FFmpeg is installed."
|
||||
)
|
||||
return error_generator()
|
||||
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
|
||||
|
||||
if self.args.transcription and self.online:
|
||||
if self.transcription:
|
||||
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
||||
self.all_tasks_for_cleanup.append(self.transcription_task)
|
||||
processing_tasks_for_watchdog.append(self.transcription_task)
|
||||
|
||||
if self.args.diarization and self.diarization:
|
||||
if self.diarization:
|
||||
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization))
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
|
||||
|
||||
if self.translation:
|
||||
self.translation_task = asyncio.create_task(self.translation_processor())
|
||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||
processing_tasks_for_watchdog.append(self.translation_task)
|
||||
|
||||
# Monitor overall system health
|
||||
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
|
||||
self.all_tasks_for_cleanup.append(self.watchdog_task)
|
||||
@@ -560,15 +556,6 @@ class AudioProcessor:
|
||||
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
|
||||
else:
|
||||
logger.info(f"{task_name} completed normally.")
|
||||
|
||||
# Check FFmpeg status through the manager
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, notifying results formatter")
|
||||
# FFmpeg manager will handle its own recovery
|
||||
elif ffmpeg_state == FFmpegState.STOPPED and not self.is_stopping:
|
||||
logger.warning("FFmpeg unexpectedly stopped, attempting restart")
|
||||
await self.ffmpeg_manager.restart()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Watchdog task cancelled.")
|
||||
@@ -578,18 +565,24 @@ class AudioProcessor:
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources when processing is complete."""
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
self.is_stopping = True
|
||||
for task in self.all_tasks_for_cleanup:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
|
||||
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
||||
if created_tasks:
|
||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||
logger.info("All processing tasks cancelled or finished.")
|
||||
await self.ffmpeg_manager.stop()
|
||||
logger.info("FFmpeg manager stopped.")
|
||||
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
|
||||
|
||||
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||
try:
|
||||
await self.ffmpeg_manager.stop()
|
||||
logger.info("FFmpeg manager stopped.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping FFmpeg manager: {e}")
|
||||
if self.diarization:
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
@@ -603,18 +596,87 @@ class AudioProcessor:
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
self.is_stopping = True
|
||||
# Signal FFmpeg manager to stop accepting data
|
||||
await self.ffmpeg_manager.stop()
|
||||
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
|
||||
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||
await self.ffmpeg_manager.stop()
|
||||
|
||||
return
|
||||
|
||||
if self.is_stopping:
|
||||
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
|
||||
return
|
||||
|
||||
success = await self.ffmpeg_manager.write_data(message)
|
||||
if not success:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot process audio")
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
if self.is_pcm_input:
|
||||
self.pcm_buffer.extend(message)
|
||||
await self.handle_pcm_data()
|
||||
else:
|
||||
if not self.ffmpeg_manager:
|
||||
logger.error("FFmpeg manager not initialized for non-PCM input.")
|
||||
return
|
||||
success = await self.ffmpeg_manager.write_data(message)
|
||||
if not success:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot process audio")
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
|
||||
async def handle_pcm_data(self):
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) < self.bytes_per_sec:
|
||||
return
|
||||
|
||||
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||
logger.warning(
|
||||
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
|
||||
f"Consider using a smaller model."
|
||||
)
|
||||
|
||||
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
|
||||
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
|
||||
|
||||
if aligned_chunk_size == 0:
|
||||
return
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
|
||||
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
|
||||
|
||||
res = None
|
||||
end_of_audio = False
|
||||
silence_buffer = None
|
||||
|
||||
if self.args.vac:
|
||||
res = self.vac(pcm_array)
|
||||
|
||||
if res is not None:
|
||||
if res.get("end", 0) > res.get("start", 0):
|
||||
end_of_audio = True
|
||||
elif self.silence: #end of silence
|
||||
self.silence = False
|
||||
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||
|
||||
if silence_buffer:
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(silence_buffer)
|
||||
|
||||
if not self.silence:
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_array.copy())
|
||||
|
||||
self.silence_duration = 0.0
|
||||
|
||||
if end_of_audio:
|
||||
self.silence = True
|
||||
self.start_silence = time()
|
||||
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -5,9 +5,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
from starlette.staticfiles import StaticFiles
|
||||
import pathlib
|
||||
import whisperlivekit.web as webpkg
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
@@ -18,16 +15,7 @@ args = parse_args()
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
#to remove after 0.2.8
|
||||
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
|
||||
logger.warning(f"""
|
||||
{'='*50}
|
||||
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
|
||||
{'='*50}
|
||||
""")
|
||||
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(
|
||||
**vars(args),
|
||||
@@ -42,8 +30,6 @@ app.add_middleware(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
web_dir = pathlib.Path(webpkg.__file__).parent
|
||||
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
@@ -54,7 +40,7 @@ async def handle_websocket_results(websocket, results_generator):
|
||||
"""Consumes results from the audio processor and sends them via WebSocket."""
|
||||
try:
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
await websocket.send_json(response.to_dict())
|
||||
# when the results_generator finishes it means all audio has been processed
|
||||
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
@@ -72,6 +58,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
)
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send config to client: {e}")
|
||||
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
@@ -127,6 +118,8 @@ def main():
|
||||
|
||||
if ssl_kwargs:
|
||||
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
||||
if args.forwarded_allow_ips:
|
||||
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
|
||||
|
||||
uvicorn.run(**uvicorn_kwargs)
|
||||
|
||||
|
||||
@@ -4,10 +4,15 @@ try:
|
||||
except ImportError:
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory
|
||||
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||
from whisperlivekit.warmup import warmup_asr, warmup_online
|
||||
from argparse import Namespace
|
||||
import sys
|
||||
|
||||
def update_with_kwargs(_dict, kwargs):
|
||||
_dict.update({
|
||||
k: v for k, v in kwargs.items() if k in _dict
|
||||
})
|
||||
return _dict
|
||||
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@@ -21,66 +26,48 @@ class TranscriptionEngine:
|
||||
if TranscriptionEngine._initialized:
|
||||
return
|
||||
|
||||
defaults = {
|
||||
global_params = {
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"warmup_file": None,
|
||||
"diarization": False,
|
||||
"punctuation_split": False,
|
||||
"min_chunk_size": 0.5,
|
||||
"model": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
"backend": "faster-whisper",
|
||||
"target_language": "",
|
||||
"vac": True,
|
||||
"vac_chunk_size": 0.04,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
"ssl_keyfile": None,
|
||||
"forwarded_allow_ips": None,
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
# whisperstreaming params:
|
||||
"buffer_trimming": "segment",
|
||||
"confidence_validation": False,
|
||||
"buffer_trimming_sec": 15,
|
||||
# simulstreaming params:
|
||||
"disable_fast_encoder": False,
|
||||
"frame_threshold": 25,
|
||||
"beams": 1,
|
||||
"decoder_type": None,
|
||||
"audio_max_len": 20.0,
|
||||
"audio_min_len": 0.0,
|
||||
"cif_ckpt_path": None,
|
||||
"never_fire": False,
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"pcm_input": False,
|
||||
"disable_punctuation_split" : False,
|
||||
"diarization_backend": "sortformer",
|
||||
# diart params:
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
}
|
||||
global_params = update_with_kwargs(global_params, kwargs)
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
transcription_common_params = {
|
||||
"backend": "simulstreaming",
|
||||
"warmup_file": None,
|
||||
"min_chunk_size": 0.5,
|
||||
"model_size": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
}
|
||||
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
||||
|
||||
if transcription_common_params['model_size'].endswith(".en"):
|
||||
transcription_common_params["lan"] = "en"
|
||||
if 'no_transcription' in kwargs:
|
||||
config_dict['transcription'] = not kwargs['no_transcription']
|
||||
global_params['transcription'] = not global_params['no_transcription']
|
||||
if 'no_vad' in kwargs:
|
||||
config_dict['vad'] = not kwargs['no_vad']
|
||||
global_params['vad'] = not kwargs['no_vad']
|
||||
if 'no_vac' in kwargs:
|
||||
config_dict['vac'] = not kwargs['no_vac']
|
||||
|
||||
config_dict.pop('no_transcription', None)
|
||||
config_dict.pop('no_vad', None)
|
||||
global_params['vac'] = not kwargs['no_vac']
|
||||
|
||||
if 'language' in kwargs:
|
||||
config_dict['lan'] = kwargs['language']
|
||||
config_dict.pop('language', None)
|
||||
|
||||
self.args = Namespace(**config_dict)
|
||||
self.args = Namespace(**{**global_params, **transcription_common_params})
|
||||
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
@@ -94,76 +81,96 @@ class TranscriptionEngine:
|
||||
if self.args.transcription:
|
||||
if self.args.backend == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
self.tokenizer = None
|
||||
simulstreaming_kwargs = {}
|
||||
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
||||
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
||||
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
|
||||
if hasattr(self.args, attr):
|
||||
simulstreaming_kwargs[attr] = getattr(self.args, attr)
|
||||
|
||||
# Add segment_length from min_chunk_size
|
||||
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
|
||||
simulstreaming_kwargs['task'] = self.args.task
|
||||
|
||||
size = self.args.model
|
||||
simulstreaming_params = {
|
||||
"disable_fast_encoder": False,
|
||||
"custom_alignment_heads": None,
|
||||
"frame_threshold": 25,
|
||||
"beams": 1,
|
||||
"decoder_type": None,
|
||||
"audio_max_len": 20.0,
|
||||
"audio_min_len": 0.0,
|
||||
"cif_ckpt_path": None,
|
||||
"never_fire": False,
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"preload_model_count": 1,
|
||||
}
|
||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||
|
||||
self.tokenizer = None
|
||||
self.asr = SimulStreamingASR(
|
||||
modelsize=size,
|
||||
lan=self.args.lan,
|
||||
cache_dir=getattr(self.args, 'model_cache_dir', None),
|
||||
model_dir=getattr(self.args, 'model_dir', None),
|
||||
**simulstreaming_kwargs
|
||||
**transcription_common_params, **simulstreaming_params
|
||||
)
|
||||
|
||||
else:
|
||||
self.asr, self.tokenizer = backend_factory(self.args)
|
||||
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
|
||||
|
||||
whisperstreaming_params = {
|
||||
"buffer_trimming": "segment",
|
||||
"confidence_validation": False,
|
||||
"buffer_trimming_sec": 15,
|
||||
}
|
||||
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
||||
|
||||
self.asr = backend_factory(
|
||||
**transcription_common_params, **whisperstreaming_params
|
||||
)
|
||||
|
||||
if self.args.diarization:
|
||||
if self.args.diarization_backend == "diart":
|
||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||
diart_params = {
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
}
|
||||
diart_params = update_with_kwargs(diart_params, kwargs)
|
||||
self.diarization_model = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
**diart_params
|
||||
)
|
||||
elif self.args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
self.diarization_model = SortformerDiarization()
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||
else:
|
||||
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
||||
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
translation_params = {
|
||||
"nllb_backend": "ctranslate2",
|
||||
"nllb_size": "600M"
|
||||
}
|
||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
|
||||
|
||||
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||
def online_factory(args, asr):
|
||||
if args.backend == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
online = SimulStreamingOnlineProcessor(
|
||||
asr,
|
||||
logfile=logfile,
|
||||
)
|
||||
# warmup_online(online, args.warmup_file)
|
||||
online = SimulStreamingOnlineProcessor(asr)
|
||||
else:
|
||||
online = OnlineASRProcessor(
|
||||
asr,
|
||||
tokenizer,
|
||||
logfile=logfile,
|
||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||
confidence_validation = args.confidence_validation
|
||||
)
|
||||
online = OnlineASRProcessor(asr)
|
||||
return online
|
||||
|
||||
|
||||
def online_diarization_factory(args, diarization_backend):
|
||||
if args.diarization_backend == "diart":
|
||||
online = diarization_backend
|
||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommanded
|
||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
||||
|
||||
if args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||
return online
|
||||
|
||||
|
||||
|
||||
def online_translation_factory(args, translation_model):
|
||||
#should be at speaker level in the future:
|
||||
#one shared nllb model for all speaker
|
||||
#one tokenizer per speaker/language
|
||||
from whisperlivekit.translation.translation import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
@@ -60,11 +60,15 @@ class SortformerDiarization:
|
||||
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||
self.diar_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.diar_model.to(torch.device("cuda"))
|
||||
logger.info("Using CUDA for Sortformer model")
|
||||
else:
|
||||
logger.info("Using CPU for Sortformer model")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.diar_model.to(device)
|
||||
|
||||
## to test
|
||||
# for name, param in self.diar_model.named_parameters():
|
||||
# if param.device != device:
|
||||
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
|
||||
|
||||
logger.info(f"Using {device.type.upper()} for Sortformer model")
|
||||
|
||||
self.diar_model.sortformer_modules.chunk_len = 10
|
||||
self.diar_model.sortformer_modules.subsampling_factor = 10
|
||||
@@ -106,6 +110,7 @@ class SortformerDiarizationOnline:
|
||||
features=128,
|
||||
pad_to=0
|
||||
)
|
||||
self.audio2mel.to(self.diar_model.device)
|
||||
|
||||
self.chunk_duration_seconds = (
|
||||
self.diar_model.sortformer_modules.chunk_len *
|
||||
@@ -186,22 +191,25 @@ class SortformerDiarizationOnline:
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)
|
||||
device = self.diar_model.device
|
||||
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||
|
||||
if self._previous_chunk_features is not None:
|
||||
to_add = self._previous_chunk_features[:, :, -99:]
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||
else:
|
||||
total_features = processed_signal_chunk
|
||||
total_features = processed_signal_chunk.to(device)
|
||||
|
||||
self._previous_chunk_features = processed_signal_chunk
|
||||
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2)
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
@@ -209,7 +217,7 @@ class SortformerDiarizationOnline:
|
||||
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
@@ -281,6 +289,7 @@ class SortformerDiarizationOnline:
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
Last speaker_segment
|
||||
"""
|
||||
with self.segment_lock:
|
||||
segments = self.speaker_segments.copy()
|
||||
|
||||
@@ -7,11 +7,12 @@ import contextlib
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
ERROR_INSTALL_INSTRUCTIONS = """
|
||||
ERROR_INSTALL_INSTRUCTIONS = f"""
|
||||
{'='*50}
|
||||
FFmpeg is not installed or not found in your system's PATH.
|
||||
Please install FFmpeg to enable audio processing.
|
||||
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
|
||||
|
||||
Installation instructions:
|
||||
If you want to install FFmpeg:
|
||||
|
||||
# Ubuntu/Debian:
|
||||
sudo apt update && sudo apt install ffmpeg
|
||||
@@ -25,6 +26,7 @@ brew install ffmpeg
|
||||
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
|
||||
|
||||
After installation, please restart the application.
|
||||
{'='*50}
|
||||
"""
|
||||
|
||||
class FFmpegState(Enum):
|
||||
@@ -183,6 +185,8 @@ class FFmpegManager:
|
||||
async def _drain_stderr(self):
|
||||
try:
|
||||
while True:
|
||||
if not self.process or not self.process.stderr:
|
||||
break
|
||||
line = await self.process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
@@ -190,4 +194,4 @@ class FFmpegManager:
|
||||
except asyncio.CancelledError:
|
||||
logger.info("FFmpeg stderr drain task cancelled.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||
|
||||
@@ -20,7 +20,7 @@ def parse_args():
|
||||
help="""
|
||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
If False, no warmup is performed.
|
||||
If empty, no warmup is performed.
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -72,6 +72,12 @@ def parse_args():
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-punctuation-split",
|
||||
action="store_true",
|
||||
help="Disable the split parameter.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
@@ -83,6 +89,7 @@ def parse_args():
|
||||
"--model",
|
||||
type=str,
|
||||
default="small",
|
||||
dest='model_size',
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
@@ -103,6 +110,7 @@ def parse_args():
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
dest='lan',
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -112,6 +120,15 @@ def parse_args():
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-language",
|
||||
type=str,
|
||||
default="",
|
||||
dest="target_language",
|
||||
help="Target language for translation. Not functional yet.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
@@ -158,7 +175,13 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
|
||||
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
|
||||
parser.add_argument(
|
||||
"--pcm-input",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
||||
)
|
||||
# SimulStreaming-specific arguments
|
||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||
|
||||
@@ -169,6 +192,13 @@ def parse_args():
|
||||
dest="disable_fast_encoder",
|
||||
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--custom-alignment-heads",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Use your own alignment heads, useful when `--model-dir` is used",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--frame-threshold",
|
||||
@@ -260,13 +290,27 @@ def parse_args():
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--preloaded_model_count",
|
||||
"--preload-model-count",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="preloaded_model_count",
|
||||
dest="preload_model_count",
|
||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="ctranslate2",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-size",
|
||||
type=str,
|
||||
default="600M",
|
||||
help="600M or 1.3B",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
|
||||
@@ -39,7 +39,7 @@ def blank_to_silence(tokens):
|
||||
)
|
||||
else:
|
||||
if silence_token: #there was silence but no more
|
||||
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
||||
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||
cleaned_tokens.append(
|
||||
silence_token
|
||||
)
|
||||
@@ -77,15 +77,9 @@ def no_token_to_silence(tokens):
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||
if not tokens:
|
||||
return [], buffer_transcription, buffer_diarization
|
||||
def ends_with_silence(tokens, current_time, vac_detected_silence):
|
||||
last_token = tokens[-1]
|
||||
if tokens and current_time and (
|
||||
current_time - last_token.end >= END_SILENCE_DURATION
|
||||
or
|
||||
(current_time - last_token.end >= 3 and vac_detected_silence)
|
||||
):
|
||||
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
||||
if last_token.speaker == -2:
|
||||
last_token.end = current_time
|
||||
else:
|
||||
@@ -97,14 +91,14 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
|
||||
probability=0.95
|
||||
)
|
||||
)
|
||||
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
|
||||
buffer_diarization = ""
|
||||
return tokens, buffer_transcription, buffer_diarization
|
||||
return tokens
|
||||
|
||||
|
||||
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||
def handle_silences(tokens, current_time, vac_detected_silence):
|
||||
if not tokens:
|
||||
return []
|
||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||
tokens = no_token_to_silence(tokens)
|
||||
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
|
||||
return tokens, buffer_transcription, buffer_diarization
|
||||
tokens = ends_with_silence(tokens, current_time, vac_detected_silence)
|
||||
return tokens
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?'}
|
||||
CHECK_AROUND = 4
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def is_punctuation(token):
|
||||
if token.text.strip() in PUNCTUATION_MARKS:
|
||||
if token.is_punctuation():
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -34,105 +30,131 @@ def next_speaker_change(i, tokens, speaker):
|
||||
return ind, token.speaker
|
||||
return None, speaker
|
||||
|
||||
|
||||
def new_line(
|
||||
token,
|
||||
speaker,
|
||||
last_end_diarized,
|
||||
debug_info = ""
|
||||
):
|
||||
return {
|
||||
"speaker": int(speaker),
|
||||
"text": token.text + debug_info,
|
||||
"beg": format_time(token.start),
|
||||
"end": format_time(token.end),
|
||||
"diff": round(token.end - last_end_diarized, 2)
|
||||
}
|
||||
return Line(
|
||||
speaker = token.corrected_speaker,
|
||||
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
|
||||
start = token.start,
|
||||
end = token.end,
|
||||
detected_language=token.detected_language
|
||||
)
|
||||
|
||||
|
||||
def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
|
||||
if token.text:
|
||||
lines[-1]["text"] += sep + token.text + debug_info
|
||||
lines[-1]["end"] = format_time(token.end)
|
||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||
def append_token_to_last_line(lines, sep, token):
|
||||
if not lines:
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
if token.text:
|
||||
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
|
||||
lines[-1].end = token.end
|
||||
if not lines[-1].detected_language and token.detected_language:
|
||||
lines[-1].detected_language = token.detected_language
|
||||
|
||||
|
||||
def format_output(state, silence, current_time, diarization, debug):
|
||||
tokens = state["tokens"]
|
||||
buffer_transcription = state["buffer_transcription"]
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
sep = state["sep"]
|
||||
def format_output(state, silence, current_time, args, sep):
|
||||
diarization = args.diarization
|
||||
disable_punctuation_split = args.disable_punctuation_split
|
||||
tokens = state.tokens
|
||||
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
last_validated_token = state.last_validated_token
|
||||
|
||||
previous_speaker = -1
|
||||
lines = []
|
||||
last_end_diarized = 0
|
||||
previous_speaker = 1
|
||||
undiarized_text = []
|
||||
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
|
||||
tokens = handle_silences(tokens, current_time, silence)
|
||||
last_punctuation = None
|
||||
for i, token in enumerate(tokens):
|
||||
speaker = token.speaker
|
||||
|
||||
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
speaker = 1
|
||||
if diarization and not tokens[-1].speaker == -2:
|
||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
if speaker not in [-1, 0]:
|
||||
last_end_diarized = max(token.end, last_end_diarized)
|
||||
|
||||
debug_info = ""
|
||||
if debug:
|
||||
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
|
||||
|
||||
if not lines:
|
||||
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
continue
|
||||
for i, token in enumerate(tokens[last_validated_token:]):
|
||||
speaker = int(token.speaker)
|
||||
token.corrected_speaker = speaker
|
||||
if not diarization:
|
||||
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
token.corrected_speaker = 1
|
||||
token.validated_speaker = True
|
||||
else:
|
||||
previous_speaker = lines[-1]['speaker']
|
||||
|
||||
if is_punctuation(token):
|
||||
last_punctuation = i
|
||||
|
||||
|
||||
if last_punctuation == i-1:
|
||||
if speaker != previous_speaker:
|
||||
# perfect, diarization perfectly aligned
|
||||
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
last_punctuation, next_punctuation = None, None
|
||||
continue
|
||||
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
|
||||
lines.append(new_line(token, new_speaker, last_end_diarized, debug_info = ""))
|
||||
else:
|
||||
# No speaker change to come
|
||||
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
|
||||
continue
|
||||
|
||||
# if token.end > end_attributed_speaker and token.speaker != -2:
|
||||
# if tokens[-1].speaker == -2: #if it finishes by a silence, we want to append the undiarized text to the last speaker.
|
||||
# token.corrected_speaker = previous_speaker
|
||||
# else:
|
||||
# undiarized_text.append(token.text)
|
||||
# continue
|
||||
# else:
|
||||
if is_punctuation(token):
|
||||
last_punctuation = i
|
||||
|
||||
if last_punctuation == i-1:
|
||||
if token.speaker != previous_speaker:
|
||||
token.validated_speaker = True
|
||||
# perfect, diarization perfectly aligned
|
||||
last_punctuation = None
|
||||
else:
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||
token.corrected_speaker = new_speaker
|
||||
token.validated_speaker = True
|
||||
elif speaker != previous_speaker:
|
||||
if not (speaker == -2 or previous_speaker == -2):
|
||||
if next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
token.corrected_speaker = previous_speaker
|
||||
token.validated_speaker = True
|
||||
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
|
||||
if not disable_punctuation_split:
|
||||
token.corrected_speaker = previous_speaker
|
||||
token.validated_speaker = False
|
||||
if token.validated_speaker:
|
||||
state.last_validated_token = i
|
||||
previous_speaker = token.corrected_speaker
|
||||
|
||||
if speaker != previous_speaker:
|
||||
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
|
||||
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
continue
|
||||
elif next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
|
||||
continue
|
||||
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
|
||||
# lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
pass
|
||||
|
||||
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
|
||||
return lines, undiarized_text, buffer_transcription, ''
|
||||
previous_speaker = 1
|
||||
|
||||
lines = []
|
||||
for token in tokens:
|
||||
if int(token.corrected_speaker) != int(previous_speaker):
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
append_token_to_last_line(lines, sep, token)
|
||||
|
||||
previous_speaker = token.corrected_speaker
|
||||
|
||||
if lines and translated_segments:
|
||||
unassigned_translated_segments = []
|
||||
for ts in translated_segments:
|
||||
assigned = False
|
||||
for line in lines:
|
||||
if ts and ts.overlaps_with(line):
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + ' '
|
||||
assigned = True
|
||||
break
|
||||
else:
|
||||
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||
if ts0 and line.overlaps_with(ts0):
|
||||
line.translation += ts0.text + ' '
|
||||
if ts1:
|
||||
unassigned_translated_segments.append(ts1)
|
||||
assigned = True
|
||||
break
|
||||
if not assigned:
|
||||
unassigned_translated_segments.append(ts)
|
||||
|
||||
if unassigned_translated_segments:
|
||||
for line in lines:
|
||||
remaining_segments = []
|
||||
for ts in unassigned_translated_segments:
|
||||
if ts and ts.overlaps_with(line):
|
||||
line.translation += ts.text + ' '
|
||||
else:
|
||||
remaining_segments.append(ts)
|
||||
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
||||
|
||||
if state.buffer_transcription and lines:
|
||||
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||
|
||||
return lines, undiarized_text
|
||||
|
||||
@@ -3,12 +3,11 @@ import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import logging
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.warmup import load_file
|
||||
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
||||
from .whisper import load_model, tokenizer
|
||||
from .whisper.audio import TOKENS_PER_SECOND
|
||||
|
||||
import os
|
||||
import gc
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,6 +21,10 @@ try:
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
print(f"""{"="*50}
|
||||
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
|
||||
{"="*50}""")
|
||||
HAS_MLX_WHISPER = False
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
@@ -42,13 +45,11 @@ class SimulStreamingOnlineProcessor:
|
||||
self,
|
||||
asr,
|
||||
logfile=sys.stderr,
|
||||
warmup_file=None
|
||||
):
|
||||
self.asr = asr
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.global_time_offset = 0.0
|
||||
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.load_new_backend()
|
||||
@@ -77,7 +78,7 @@ class SimulStreamingOnlineProcessor:
|
||||
else:
|
||||
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.global_time_offset += silence_duration + offset
|
||||
self.model.global_time_offset = silence_duration + offset
|
||||
|
||||
|
||||
|
||||
@@ -89,63 +90,15 @@ class SimulStreamingOnlineProcessor:
|
||||
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def get_buffer(self):
|
||||
return Transcript(
|
||||
start=None,
|
||||
end=None,
|
||||
text='',
|
||||
probability=None
|
||||
)
|
||||
|
||||
def timestamped_text(self, tokens, generation):
|
||||
"""
|
||||
generate timestamped text from tokens and generation data.
|
||||
|
||||
args:
|
||||
tokens: List of tokens to process
|
||||
generation: Dictionary containing generation progress and optionally results
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
self.process_iter(is_last=True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.speaker = change_speaker.speaker
|
||||
self.global_time_offset = change_speaker.start
|
||||
|
||||
returns:
|
||||
List of tuples containing (start_time, end_time, word) for each word
|
||||
"""
|
||||
FRAME_DURATION = 0.02
|
||||
if "result" in generation:
|
||||
split_words = generation["result"]["split_words"]
|
||||
split_tokens = generation["result"]["split_tokens"]
|
||||
else:
|
||||
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
|
||||
progress = generation["progress"]
|
||||
frames = [p["most_attended_frames"][0] for p in progress]
|
||||
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
|
||||
tokens_queue = tokens.copy()
|
||||
timestamped_words = []
|
||||
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
# start_frame = None
|
||||
# end_frame = None
|
||||
for expected_token in word_tokens:
|
||||
if not tokens_queue or not frames:
|
||||
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
|
||||
|
||||
actual_token = tokens_queue.pop(0)
|
||||
current_frame = frames.pop(0)
|
||||
current_timestamp = absolute_timestamps.pop(0)
|
||||
if actual_token != expected_token:
|
||||
raise ValueError(
|
||||
f"Token mismatch: expected '{expected_token}', "
|
||||
f"got '{actual_token}' at frame {current_frame}"
|
||||
)
|
||||
# if start_frame is None:
|
||||
# start_frame = current_frame
|
||||
# end_frame = current_frame
|
||||
# start_time = start_frame * FRAME_DURATION
|
||||
# end_time = end_frame * FRAME_DURATION
|
||||
start_time = current_timestamp
|
||||
end_time = current_timestamp + 0.1
|
||||
timestamp_entry = (start_time, end_time, word)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
|
||||
return timestamped_words
|
||||
def get_buffer(self):
|
||||
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||
return concat_buffer
|
||||
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
@@ -154,47 +107,14 @@ class SimulStreamingOnlineProcessor:
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
try:
|
||||
tokens, generation_progress = self.model.infer(is_last=is_last)
|
||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||
timestamped_words = self.model.infer(is_last=is_last)
|
||||
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
new_tokens = []
|
||||
for ts_word in ts_words:
|
||||
|
||||
start, end, word = ts_word
|
||||
token = ASRToken(
|
||||
start=start,
|
||||
end=end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
new_tokens.append(token)
|
||||
|
||||
# identical_tokens = 0
|
||||
# n_new_tokens = len(new_tokens)
|
||||
# if n_new_tokens:
|
||||
|
||||
self.committed.extend(new_tokens)
|
||||
|
||||
# if token in self.committed:
|
||||
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
|
||||
# if pos:
|
||||
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
|
||||
# commited_segment = self.committed[i:i+n_new_tokens]
|
||||
# if commited_segment == new_tokens:
|
||||
# identical_segments +=1
|
||||
# if identical_tokens >= TOO_MANY_REPETITIONS:
|
||||
# logger.warning('Too many repetition, model is stuck. Load a new one')
|
||||
# self.committed = self.committed[:i]
|
||||
# self.load_new_backend()
|
||||
# return [], self.end
|
||||
|
||||
# pos = self.committed.rindex(token)
|
||||
|
||||
|
||||
|
||||
return new_tokens, self.end
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
|
||||
|
||||
except Exception as e:
|
||||
@@ -223,32 +143,20 @@ class SimulStreamingASR():
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = ""
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||
logger.warning(SIMULSTREAMING_LICENSE)
|
||||
def __init__(self, logfile=sys.stderr, **kwargs):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = lan
|
||||
|
||||
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
||||
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
||||
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
|
||||
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
||||
self.segment_length = kwargs.get('segment_length', 0.5)
|
||||
self.beams = kwargs.get('beams', 1)
|
||||
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
||||
self.task = kwargs.get('task', 'transcribe')
|
||||
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
||||
self.never_fire = kwargs.get('never_fire', False)
|
||||
self.init_prompt = kwargs.get('init_prompt', None)
|
||||
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
||||
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
||||
self.warmup_file = kwargs.get('warmup_file', None)
|
||||
self.preload_model_count = kwargs.get('preload_model_count', 1)
|
||||
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
if self.decoder_type is None:
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
if model_dir is not None:
|
||||
self.model_path = model_dir
|
||||
elif modelsize is not None:
|
||||
if self.model_dir is not None:
|
||||
self.model_path = self.model_dir
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
@@ -263,13 +171,13 @@ class SimulStreamingASR():
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
||||
self.model_path = model_mapping.get(self.model_size, f'./{self.model_size}.pt')
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
model_path=self.model_path,
|
||||
segment_length=self.segment_length,
|
||||
segment_length=self.min_chunk_size,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.original_language,
|
||||
language=self.lan,
|
||||
audio_max_len=self.audio_max_len,
|
||||
audio_min_len=self.audio_min_len,
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
@@ -288,8 +196,12 @@ class SimulStreamingASR():
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||
if self.model_dir:
|
||||
self.model_name = self.model_dir
|
||||
self.model_path = None
|
||||
else:
|
||||
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if not self.disable_fast_encoder:
|
||||
@@ -311,7 +223,12 @@ class SimulStreamingASR():
|
||||
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
|
||||
whisper_model = load_model(
|
||||
name=self.model_name,
|
||||
download_root=self.model_path,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
)
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||
@@ -327,7 +244,7 @@ class SimulStreamingASR():
|
||||
else:
|
||||
# For standard encoder, use the original transcribe warmup
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
|
||||
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
||||
return whisper_model
|
||||
|
||||
def get_new_model_instance(self):
|
||||
@@ -360,4 +277,4 @@ class SimulStreamingASR():
|
||||
"""
|
||||
Warmup is done directly in load_model
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
|
||||
from .whisper import load_model, DecodingOptions, tokenizer
|
||||
from .config import AlignAttConfig
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from .whisper.timing import median_filter
|
||||
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||
@@ -18,6 +19,7 @@ from time import time
|
||||
from .token_buffer import TokenBuffer
|
||||
|
||||
import numpy as np
|
||||
from ..timed_objects import PUNCTUATION_MARKS
|
||||
from .generation_progress import *
|
||||
|
||||
DEC_PAD = 50257
|
||||
@@ -40,12 +42,6 @@ else:
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
# New features added to the original version of Simul-Whisper:
|
||||
# - large-v3 model support
|
||||
# - translation support
|
||||
# - beam search
|
||||
# - prompt -- static vs. non-static
|
||||
# - context
|
||||
class PaddedAlignAttWhisper:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -61,6 +57,8 @@ class PaddedAlignAttWhisper:
|
||||
self.model = loaded_model
|
||||
else:
|
||||
self.model = load_model(name=model_name, download_root=model_path)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
@@ -68,7 +66,7 @@ class PaddedAlignAttWhisper:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
|
||||
self.speaker = -1
|
||||
self.decode_options = DecodingOptions(
|
||||
language = cfg.language,
|
||||
without_timestamps = True,
|
||||
@@ -76,7 +74,10 @@ class PaddedAlignAttWhisper:
|
||||
)
|
||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
# self.create_tokenizer('en')
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
self.global_time_offset = 0.0
|
||||
self.reset_tokenizer_to_auto_next_call = False
|
||||
|
||||
self.max_text_len = self.model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||
@@ -151,6 +152,7 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.first_timestamp = None
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
@@ -172,7 +174,6 @@ class PaddedAlignAttWhisper:
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
|
||||
def remove_hooks(self):
|
||||
print('remove hook')
|
||||
for hook in self.l_hooks:
|
||||
hook.remove()
|
||||
|
||||
@@ -259,7 +260,6 @@ class PaddedAlignAttWhisper:
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
if not complete and len(self.segments) > 2:
|
||||
logger.debug("keeping last two segments because they are and it is not complete.")
|
||||
self.segments = self.segments[-2:]
|
||||
else:
|
||||
logger.debug("removing all segments.")
|
||||
@@ -381,11 +381,11 @@ class PaddedAlignAttWhisper:
|
||||
new_segment = True
|
||||
if len(self.segments) == 0:
|
||||
logger.debug("No segments, nothing to do")
|
||||
return [], {}
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
return [], {}
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
if len(self.segments) > 1:
|
||||
@@ -393,88 +393,77 @@ class PaddedAlignAttWhisper:
|
||||
else:
|
||||
input_segments = self.segments[0]
|
||||
|
||||
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
|
||||
# logger.debug("Resetting tokenizer to auto for new sentence.")
|
||||
# self.create_tokenizer(None)
|
||||
# self.detected_language = None
|
||||
# self.init_tokens()
|
||||
# self.reset_tokenizer_to_auto_next_call = False
|
||||
|
||||
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
|
||||
beg_encode = time()
|
||||
if self.mlx_encoder:
|
||||
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
|
||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
||||
encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
|
||||
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
|
||||
device = 'cpu'
|
||||
elif self.fw_encoder:
|
||||
audio_length_seconds = len(input_segments) / 16000
|
||||
content_mel_len = int(audio_length_seconds * 100)//2
|
||||
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
|
||||
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
||||
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
||||
encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate))
|
||||
device = 'cpu'
|
||||
if self.device == 'cpu': #it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works
|
||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||
try:
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case:
|
||||
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
|
||||
else:
|
||||
# mel + padding to 30s
|
||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
device=self.device).unsqueeze(0)
|
||||
# trim to 3000
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
# the len of actual audio
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||
encoder_feature = self.model.encoder(mel)
|
||||
device = mel.device
|
||||
end_encode = time()
|
||||
# print('Encoder duration:', end_encode-beg_encode)
|
||||
|
||||
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
|
||||
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
# logger.debug("mel ")
|
||||
if self.cfg.language == "auto" and self.detected_language is None:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
#self.tokenizer.language = top_lan
|
||||
#self.tokenizer.__post_init__()
|
||||
self.create_tokenizer(top_lan)
|
||||
self.detected_language = top_lan
|
||||
self.init_tokens()
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.first_timestamp
|
||||
if seconds_since_start >= 2.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
self.create_tokenizer(top_lan)
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
self.detected_language = top_lan
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
|
||||
self.trim_context()
|
||||
current_tokens = self._current_tokens()
|
||||
#
|
||||
|
||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||
|
||||
|
||||
####################### Decoding loop
|
||||
logger.info("Decoding loop starts\n")
|
||||
|
||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
|
||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
||||
completed = False
|
||||
# punctuation_stop = False
|
||||
|
||||
attn_of_alignment_heads = None
|
||||
most_attended_frame = None
|
||||
|
||||
token_len_before_decoding = current_tokens.shape[1]
|
||||
|
||||
generation_progress = []
|
||||
generation = {
|
||||
"starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
|
||||
"token_len_before_decoding": token_len_before_decoding,
|
||||
#"fire_detected": fire_detected,
|
||||
"frames_len": content_mel_len,
|
||||
"frames_threshold": 4 if is_last else self.cfg.frame_threshold,
|
||||
|
||||
# to be filled later
|
||||
"logits_starting": None,
|
||||
|
||||
# to be filled later
|
||||
"no_speech_prob": None,
|
||||
"no_speech": False,
|
||||
|
||||
# to be filled in the loop
|
||||
"progress": generation_progress,
|
||||
}
|
||||
l_absolute_timestamps = []
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
generation_progress_loop = []
|
||||
|
||||
if new_segment:
|
||||
tokens_for_logits = current_tokens
|
||||
@@ -483,50 +472,26 @@ class PaddedAlignAttWhisper:
|
||||
tokens_for_logits = current_tokens[:,-1:]
|
||||
|
||||
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
||||
if new_segment:
|
||||
generation["logits_starting"] = Logits(logits[:,:,:])
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
generation["no_speech_prob"] = no_speech_probs[0]
|
||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||
generation["no_speech"] = True
|
||||
logger.info("no speech, stop")
|
||||
break
|
||||
|
||||
logits = logits[:, -1, :] # logits for the last token
|
||||
generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
|
||||
|
||||
# supress blank tokens only at the beginning of the segment
|
||||
if new_segment:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
new_segment = False
|
||||
self.suppress_tokens(logits)
|
||||
#generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
|
||||
generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
|
||||
|
||||
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
|
||||
generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
|
||||
generation_progress_loop.append(("completed",completed))
|
||||
|
||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
|
||||
|
||||
# if self.decoder_type == "beam":
|
||||
# logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
|
||||
|
||||
# logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
# idx = 0
|
||||
# logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
|
||||
# logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
|
||||
# if completed:
|
||||
# self.debug_print_tokens(current_tokens)
|
||||
|
||||
# logger.debug("decode stopped because decoder completed")
|
||||
|
||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
||||
for i, attn_mat in enumerate(self.dec_attns):
|
||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
||||
@@ -545,30 +510,24 @@ class PaddedAlignAttWhisper:
|
||||
t = torch.cat(mat, dim=1)
|
||||
tmp.append(t)
|
||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
|
||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
||||
# logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
|
||||
|
||||
# for each beam, the most attended frame is:
|
||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
|
||||
|
||||
# Calculate absolute timestamps accounting for cumulative offset
|
||||
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
||||
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
|
||||
|
||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
||||
|
||||
most_attended_frame = most_attended_frames[0].item()
|
||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||
|
||||
|
||||
generation_progress.append(dict(generation_progress_loop))
|
||||
logger.debug("current tokens" + str(current_tokens.shape))
|
||||
if completed:
|
||||
# # stripping the last token, the eot
|
||||
@@ -606,66 +565,53 @@ class PaddedAlignAttWhisper:
|
||||
self.tokenizer.decode([current_tokens[i, -1].item()])
|
||||
))
|
||||
|
||||
# for k,v in generation.items():
|
||||
# print(k,v,file=sys.stderr)
|
||||
# for x in generation_progress:
|
||||
# for y in x.items():
|
||||
# print("\t\t",*y,file=sys.stderr)
|
||||
# print("\t","----", file=sys.stderr)
|
||||
# print("\t", "end of generation_progress_loop", file=sys.stderr)
|
||||
# sys.exit(1)
|
||||
####################### End of decoding loop
|
||||
|
||||
logger.info("End of decoding loop")
|
||||
|
||||
# if attn_of_alignment_heads is not None:
|
||||
# seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
|
||||
|
||||
# # Lets' now consider only the top hypothesis in the beam search
|
||||
# top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
|
||||
|
||||
# # debug print: how is the new token attended?
|
||||
# new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
|
||||
# logger.debug(f"New token attention shape: {new_token_attn.shape}")
|
||||
# if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
|
||||
# logger.debug("no token generated")
|
||||
# else: # it is, and the max attention is:
|
||||
# new_token_max_attn, _ = new_token_attn.max(dim=-1)
|
||||
# logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
|
||||
|
||||
|
||||
# let's now operate only with the top beam hypothesis
|
||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||
if fire_detected or is_last:
|
||||
|
||||
if fire_detected or is_last: #or punctuation_stop:
|
||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||
else:
|
||||
# going to truncate the tokens after the last space
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
|
||||
generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
|
||||
generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
|
||||
|
||||
# text_to_split = self.tokenizer.decode(tokens_to_split)
|
||||
# logger.debug(f"text_to_split: {text_to_split}")
|
||||
# logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
|
||||
# text_before_space = " ".join(text_to_split.split(" ")[:-1])
|
||||
# logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
|
||||
if len(split_words) > 1:
|
||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||
else:
|
||||
new_hypothesis = []
|
||||
|
||||
|
||||
### new hypothesis
|
||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||
device=self.model.device,
|
||||
device=self.device,
|
||||
)
|
||||
self.tokens.append(new_tokens)
|
||||
# TODO: test if this is redundant or not
|
||||
# ret = ret[ret<DEC_PAD]
|
||||
|
||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||
|
||||
self._clean_cache()
|
||||
|
||||
return new_hypothesis, generation
|
||||
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
|
||||
self.first_timestamp = l_absolute_timestamps[0]
|
||||
|
||||
|
||||
timestamped_words = []
|
||||
timestamp_idx = 0
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
try:
|
||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||
except:
|
||||
pass
|
||||
timestamp_idx += len(word_tokens)
|
||||
|
||||
timestamp_entry = ASRToken(
|
||||
start=current_timestamp,
|
||||
end=current_timestamp + 0.1,
|
||||
text= word,
|
||||
probability=0.95,
|
||||
speaker=self.speaker,
|
||||
detected_language=self.detected_language
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
return timestamped_words
|
||||
@@ -105,7 +105,8 @@ def load_model(
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
decoder_only=False
|
||||
decoder_only=False,
|
||||
custom_alignment_heads=None
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
@@ -135,15 +136,17 @@ def load_model(
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
alignment_heads = None
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
|
||||
if custom_alignment_heads:
|
||||
alignment_heads = custom_alignment_heads.encode()
|
||||
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
|
||||
@@ -1,20 +1,57 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Any, List
|
||||
from datetime import timedelta
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedText:
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
start: Optional[float] = 0
|
||||
end: Optional[float] = 0
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
probability: Optional[float] = None
|
||||
is_dummy: Optional[bool] = False
|
||||
detected_language: Optional[str] = None
|
||||
|
||||
def is_punctuation(self):
|
||||
return self.text.strip() in PUNCTUATION_MARKS
|
||||
|
||||
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||
return not (self.end <= other.start or other.end <= self.start)
|
||||
|
||||
def is_within(self, other: 'TimedText') -> bool:
|
||||
return other.contains_timespan(self)
|
||||
|
||||
@dataclass
|
||||
def duration(self) -> float:
|
||||
return self.end - self.start
|
||||
|
||||
def contains_time(self, time: float) -> bool:
|
||||
return self.start <= time <= self.end
|
||||
|
||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||
return self.start <= other.start and self.end >= other.end
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.text)
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
|
||||
corrected_speaker: Optional[int] = -1
|
||||
validated_speaker: bool = False
|
||||
validated_text: bool = False
|
||||
validated_language: bool = False
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language)
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
@@ -22,7 +59,28 @@ class Sentence(TimedText):
|
||||
|
||||
@dataclass
|
||||
class Transcript(TimedText):
|
||||
pass
|
||||
"""
|
||||
represents a concatenation of several ASRToken
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_tokens(
|
||||
cls,
|
||||
tokens: List[ASRToken],
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> "Transcript":
|
||||
sep = sep if sep is not None else ' '
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return cls(start, end, text, probability=probability)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeakerSegment(TimedText):
|
||||
@@ -31,6 +89,96 @@ class SpeakerSegment(TimedText):
|
||||
"""
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Translation(TimedText):
|
||||
pass
|
||||
|
||||
def approximate_cut_at(self, cut_time):
|
||||
"""
|
||||
Each word in text is considered to be of duration (end-start)/len(words in text)
|
||||
"""
|
||||
if not self.text or not self.contains_time(cut_time):
|
||||
return self, None
|
||||
|
||||
words = self.text.split()
|
||||
num_words = len(words)
|
||||
if num_words == 0:
|
||||
return self, None
|
||||
|
||||
duration_per_word = self.duration() / num_words
|
||||
|
||||
cut_word_index = int((cut_time - self.start) / duration_per_word)
|
||||
|
||||
if cut_word_index >= num_words:
|
||||
cut_word_index = num_words -1
|
||||
|
||||
text0 = " ".join(words[:cut_word_index])
|
||||
text1 = " ".join(words[cut_word_index:])
|
||||
|
||||
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||
|
||||
return segment0, segment1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: float
|
||||
duration: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||
'text': self.text,
|
||||
'start': format_time(self.start),
|
||||
'end': format_time(self.end),
|
||||
}
|
||||
if self.translation:
|
||||
_dict['translation'] = self.translation
|
||||
if self.detected_language:
|
||||
_dict['detected_language'] = self.detected_language
|
||||
return _dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontData():
|
||||
status: str = ''
|
||||
error: str = ''
|
||||
lines: list[Line] = field(default_factory=list)
|
||||
buffer_transcription: str = ''
|
||||
buffer_diarization: str = ''
|
||||
remaining_time_transcription: float = 0.
|
||||
remaining_time_diarization: float = 0.
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
'status': self.status,
|
||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||
'buffer_transcription': self.buffer_transcription,
|
||||
'buffer_diarization': self.buffer_diarization,
|
||||
'remaining_time_transcription': self.remaining_time_transcription,
|
||||
'remaining_time_diarization': self.remaining_time_diarization,
|
||||
}
|
||||
if self.error:
|
||||
_dict['error'] = self.error
|
||||
return _dict
|
||||
|
||||
@dataclass
|
||||
class ChangeSpeaker:
|
||||
speaker: int
|
||||
start: int
|
||||
|
||||
@dataclass
|
||||
class State():
|
||||
tokens: list
|
||||
last_validated_token: int
|
||||
translated_segments: list
|
||||
buffer_transcription: str
|
||||
end_buffer: float
|
||||
end_attributed_speaker: float
|
||||
remaining_time_transcription: float
|
||||
remaining_time_diarization: float
|
||||
|
||||
0
whisperlivekit/translation/__init__.py
Normal file
182
whisperlivekit/translation/mapping_languages.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
adapted from https://store.crowdin.com/custom-mt
|
||||
"""
|
||||
|
||||
LANGUAGES = [
|
||||
{"name": "Afrikaans", "nllb": "afr_Latn", "crowdin": "af"},
|
||||
{"name": "Akan", "nllb": "aka_Latn", "crowdin": "ak"},
|
||||
{"name": "Amharic", "nllb": "amh_Ethi", "crowdin": "am"},
|
||||
{"name": "Assamese", "nllb": "asm_Beng", "crowdin": "as"},
|
||||
{"name": "Asturian", "nllb": "ast_Latn", "crowdin": "ast"},
|
||||
{"name": "Bashkir", "nllb": "bak_Cyrl", "crowdin": "ba"},
|
||||
{"name": "Bambara", "nllb": "bam_Latn", "crowdin": "bm"},
|
||||
{"name": "Balinese", "nllb": "ban_Latn", "crowdin": "ban"},
|
||||
{"name": "Belarusian", "nllb": "bel_Cyrl", "crowdin": "be"},
|
||||
{"name": "Bengali", "nllb": "ben_Beng", "crowdin": "bn"},
|
||||
{"name": "Bosnian", "nllb": "bos_Latn", "crowdin": "bs"},
|
||||
{"name": "Bulgarian", "nllb": "bul_Cyrl", "crowdin": "bg"},
|
||||
{"name": "Catalan", "nllb": "cat_Latn", "crowdin": "ca"},
|
||||
{"name": "Cebuano", "nllb": "ceb_Latn", "crowdin": "ceb"},
|
||||
{"name": "Czech", "nllb": "ces_Latn", "crowdin": "cs"},
|
||||
{"name": "Welsh", "nllb": "cym_Latn", "crowdin": "cy"},
|
||||
{"name": "Danish", "nllb": "dan_Latn", "crowdin": "da"},
|
||||
{"name": "German", "nllb": "deu_Latn", "crowdin": "de"},
|
||||
{"name": "Dzongkha", "nllb": "dzo_Tibt", "crowdin": "dz"},
|
||||
{"name": "Greek", "nllb": "ell_Grek", "crowdin": "el"},
|
||||
{"name": "English", "nllb": "eng_Latn", "crowdin": "en"},
|
||||
{"name": "Esperanto", "nllb": "epo_Latn", "crowdin": "eo"},
|
||||
{"name": "Estonian", "nllb": "est_Latn", "crowdin": "et"},
|
||||
{"name": "Basque", "nllb": "eus_Latn", "crowdin": "eu"},
|
||||
{"name": "Ewe", "nllb": "ewe_Latn", "crowdin": "ee"},
|
||||
{"name": "Faroese", "nllb": "fao_Latn", "crowdin": "fo"},
|
||||
{"name": "Fijian", "nllb": "fij_Latn", "crowdin": "fj"},
|
||||
{"name": "Finnish", "nllb": "fin_Latn", "crowdin": "fi"},
|
||||
{"name": "French", "nllb": "fra_Latn", "crowdin": "fr"},
|
||||
{"name": "Friulian", "nllb": "fur_Latn", "crowdin": "fur-IT"},
|
||||
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "crowdin": "gd"},
|
||||
{"name": "Irish", "nllb": "gle_Latn", "crowdin": "ga-IE"},
|
||||
{"name": "Galician", "nllb": "glg_Latn", "crowdin": "gl"},
|
||||
{"name": "Guarani", "nllb": "grn_Latn", "crowdin": "gn"},
|
||||
{"name": "Gujarati", "nllb": "guj_Gujr", "crowdin": "gu-IN"},
|
||||
{"name": "Haitian Creole", "nllb": "hat_Latn", "crowdin": "ht"},
|
||||
{"name": "Hausa", "nllb": "hau_Latn", "crowdin": "ha"},
|
||||
{"name": "Hebrew", "nllb": "heb_Hebr", "crowdin": "he"},
|
||||
{"name": "Hindi", "nllb": "hin_Deva", "crowdin": "hi"},
|
||||
{"name": "Croatian", "nllb": "hrv_Latn", "crowdin": "hr"},
|
||||
{"name": "Hungarian", "nllb": "hun_Latn", "crowdin": "hu"},
|
||||
{"name": "Armenian", "nllb": "hye_Armn", "crowdin": "hy-AM"},
|
||||
{"name": "Igbo", "nllb": "ibo_Latn", "crowdin": "ig"},
|
||||
{"name": "Indonesian", "nllb": "ind_Latn", "crowdin": "id"},
|
||||
{"name": "Icelandic", "nllb": "isl_Latn", "crowdin": "is"},
|
||||
{"name": "Italian", "nllb": "ita_Latn", "crowdin": "it"},
|
||||
{"name": "Javanese", "nllb": "jav_Latn", "crowdin": "jv"},
|
||||
{"name": "Japanese", "nllb": "jpn_Jpan", "crowdin": "ja"},
|
||||
{"name": "Kabyle", "nllb": "kab_Latn", "crowdin": "kab"},
|
||||
{"name": "Kannada", "nllb": "kan_Knda", "crowdin": "kn"},
|
||||
{"name": "Georgian", "nllb": "kat_Geor", "crowdin": "ka"},
|
||||
{"name": "Kazakh", "nllb": "kaz_Cyrl", "crowdin": "kk"},
|
||||
{"name": "Khmer", "nllb": "khm_Khmr", "crowdin": "km"},
|
||||
{"name": "Kinyarwanda", "nllb": "kin_Latn", "crowdin": "rw"},
|
||||
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "crowdin": "ky"},
|
||||
{"name": "Korean", "nllb": "kor_Hang", "crowdin": "ko"},
|
||||
{"name": "Lao", "nllb": "lao_Laoo", "crowdin": "lo"},
|
||||
{"name": "Ligurian", "nllb": "lij_Latn", "crowdin": "lij"},
|
||||
{"name": "Limburgish", "nllb": "lim_Latn", "crowdin": "li"},
|
||||
{"name": "Lingala", "nllb": "lin_Latn", "crowdin": "ln"},
|
||||
{"name": "Lithuanian", "nllb": "lit_Latn", "crowdin": "lt"},
|
||||
{"name": "Luxembourgish", "nllb": "ltz_Latn", "crowdin": "lb"},
|
||||
{"name": "Maithili", "nllb": "mai_Deva", "crowdin": "mai"},
|
||||
{"name": "Malayalam", "nllb": "mal_Mlym", "crowdin": "ml-IN"},
|
||||
{"name": "Marathi", "nllb": "mar_Deva", "crowdin": "mr"},
|
||||
{"name": "Macedonian", "nllb": "mkd_Cyrl", "crowdin": "mk"},
|
||||
{"name": "Maltese", "nllb": "mlt_Latn", "crowdin": "mt"},
|
||||
{"name": "Mossi", "nllb": "mos_Latn", "crowdin": "mos"},
|
||||
{"name": "Maori", "nllb": "mri_Latn", "crowdin": "mi"},
|
||||
{"name": "Burmese", "nllb": "mya_Mymr", "crowdin": "my"},
|
||||
{"name": "Dutch", "nllb": "nld_Latn", "crowdin": "nl"},
|
||||
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "crowdin": "nn-NO"},
|
||||
{"name": "Nepali", "nllb": "npi_Deva", "crowdin": "ne-NP"},
|
||||
{"name": "Northern Sotho", "nllb": "nso_Latn", "crowdin": "nso"},
|
||||
{"name": "Occitan", "nllb": "oci_Latn", "crowdin": "oc"},
|
||||
{"name": "Odia", "nllb": "ory_Orya", "crowdin": "or"},
|
||||
{"name": "Papiamento", "nllb": "pap_Latn", "crowdin": "pap"},
|
||||
{"name": "Polish", "nllb": "pol_Latn", "crowdin": "pl"},
|
||||
{"name": "Portuguese", "nllb": "por_Latn", "crowdin": "pt-PT"},
|
||||
{"name": "Dari", "nllb": "prs_Arab", "crowdin": "fa-AF"},
|
||||
{"name": "Romanian", "nllb": "ron_Latn", "crowdin": "ro"},
|
||||
{"name": "Rundi", "nllb": "run_Latn", "crowdin": "rn"},
|
||||
{"name": "Russian", "nllb": "rus_Cyrl", "crowdin": "ru"},
|
||||
{"name": "Sango", "nllb": "sag_Latn", "crowdin": "sg"},
|
||||
{"name": "Sanskrit", "nllb": "san_Deva", "crowdin": "sa"},
|
||||
{"name": "Santali", "nllb": "sat_Olck", "crowdin": "sat"},
|
||||
{"name": "Sinhala", "nllb": "sin_Sinh", "crowdin": "si-LK"},
|
||||
{"name": "Slovak", "nllb": "slk_Latn", "crowdin": "sk"},
|
||||
{"name": "Slovenian", "nllb": "slv_Latn", "crowdin": "sl"},
|
||||
{"name": "Shona", "nllb": "sna_Latn", "crowdin": "sn"},
|
||||
{"name": "Sindhi", "nllb": "snd_Arab", "crowdin": "sd"},
|
||||
{"name": "Somali", "nllb": "som_Latn", "crowdin": "so"},
|
||||
{"name": "Southern Sotho", "nllb": "sot_Latn", "crowdin": "st"},
|
||||
{"name": "Spanish", "nllb": "spa_Latn", "crowdin": "es-ES"},
|
||||
{"name": "Sardinian", "nllb": "srd_Latn", "crowdin": "sc"},
|
||||
{"name": "Swati", "nllb": "ssw_Latn", "crowdin": "ss"},
|
||||
{"name": "Sundanese", "nllb": "sun_Latn", "crowdin": "su"},
|
||||
{"name": "Swedish", "nllb": "swe_Latn", "crowdin": "sv-SE"},
|
||||
{"name": "Swahili", "nllb": "swh_Latn", "crowdin": "sw"},
|
||||
{"name": "Tamil", "nllb": "tam_Taml", "crowdin": "ta"},
|
||||
{"name": "Tatar", "nllb": "tat_Cyrl", "crowdin": "tt-RU"},
|
||||
{"name": "Telugu", "nllb": "tel_Telu", "crowdin": "te"},
|
||||
{"name": "Tajik", "nllb": "tgk_Cyrl", "crowdin": "tg"},
|
||||
{"name": "Tagalog", "nllb": "tgl_Latn", "crowdin": "tl"},
|
||||
{"name": "Thai", "nllb": "tha_Thai", "crowdin": "th"},
|
||||
{"name": "Tigrinya", "nllb": "tir_Ethi", "crowdin": "ti"},
|
||||
{"name": "Tswana", "nllb": "tsn_Latn", "crowdin": "tn"},
|
||||
{"name": "Tsonga", "nllb": "tso_Latn", "crowdin": "ts"},
|
||||
{"name": "Turkmen", "nllb": "tuk_Latn", "crowdin": "tk"},
|
||||
{"name": "Turkish", "nllb": "tur_Latn", "crowdin": "tr"},
|
||||
{"name": "Uyghur", "nllb": "uig_Arab", "crowdin": "ug"},
|
||||
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "crowdin": "uk"},
|
||||
{"name": "Venetian", "nllb": "vec_Latn", "crowdin": "vec"},
|
||||
{"name": "Vietnamese", "nllb": "vie_Latn", "crowdin": "vi"},
|
||||
{"name": "Wolof", "nllb": "wol_Latn", "crowdin": "wo"},
|
||||
{"name": "Xhosa", "nllb": "xho_Latn", "crowdin": "xh"},
|
||||
{"name": "Yoruba", "nllb": "yor_Latn", "crowdin": "yo"},
|
||||
{"name": "Zulu", "nllb": "zul_Latn", "crowdin": "zu"},
|
||||
]
|
||||
|
||||
NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NAME_TO_CROWDIN = {lang["name"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NLLB = {lang["crowdin"]: lang["nllb"] for lang in LANGUAGES}
|
||||
NLLB_TO_CROWDIN = {lang["nllb"]: lang["crowdin"] for lang in LANGUAGES}
|
||||
CROWDIN_TO_NAME = {lang["crowdin"]: lang["name"] for lang in LANGUAGES}
|
||||
NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
|
||||
|
||||
|
||||
def get_nllb_code(crowdin_code):
|
||||
return CROWDIN_TO_NLLB.get(crowdin_code, None)
|
||||
|
||||
|
||||
def get_crowdin_code(nllb_code):
|
||||
return NLLB_TO_CROWDIN.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_name_by_crowdin(crowdin_code):
|
||||
return CROWDIN_TO_NAME.get(crowdin_code)
|
||||
|
||||
|
||||
def get_language_name_by_nllb(nllb_code):
|
||||
return NLLB_TO_NAME.get(nllb_code)
|
||||
|
||||
|
||||
def get_language_info(identifier, identifier_type="auto"):
|
||||
if identifier_type == "auto":
|
||||
for lang in LANGUAGES:
|
||||
if (lang["name"].lower() == identifier.lower() or
|
||||
lang["nllb"] == identifier or
|
||||
lang["crowdin"] == identifier):
|
||||
return lang
|
||||
elif identifier_type == "name":
|
||||
for lang in LANGUAGES:
|
||||
if lang["name"].lower() == identifier.lower():
|
||||
return lang
|
||||
elif identifier_type == "nllb":
|
||||
for lang in LANGUAGES:
|
||||
if lang["nllb"] == identifier:
|
||||
return lang
|
||||
elif identifier_type == "crowdin":
|
||||
for lang in LANGUAGES:
|
||||
if lang["crowdin"] == identifier:
|
||||
return lang
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def list_all_languages():
|
||||
return [lang["name"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_nllb_codes():
|
||||
return [lang["nllb"] for lang in LANGUAGES]
|
||||
|
||||
|
||||
def list_all_crowdin_codes():
|
||||
return [lang["crowdin"] for lang in LANGUAGES]
|
||||
169
whisperlivekit/translation/translation.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import logging
|
||||
import time
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
from dataclasses import dataclass, field
|
||||
import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from whisperlivekit.timed_objects import Translation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
||||
|
||||
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
|
||||
# sentence is not finished.
|
||||
|
||||
@dataclass
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
device: str
|
||||
tokenizer: dict = field(default_factory=dict)
|
||||
backend_type: str = 'ctranslate2'
|
||||
nllb_size: str = '600M'
|
||||
|
||||
def get_tokenizer(self, input_lang):
|
||||
if not self.tokenizer.get(input_lang, False):
|
||||
self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained(
|
||||
f"facebook/nllb-200-distilled-{self.nllb_size}",
|
||||
src_lang=input_lang,
|
||||
clean_up_tokenization_spaces=True
|
||||
)
|
||||
return self.tokenizer[input_lang]
|
||||
|
||||
|
||||
def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
MODEL = f'nllb-200-distilled-{nllb_size}-ctranslate2'
|
||||
if nllb_backend=='ctranslate2':
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
elif nllb_backend=='transformers':
|
||||
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{nllb_size}")
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
if src_lang != 'auto':
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
|
||||
translation_model = TranslationModel(
|
||||
translator=translator,
|
||||
tokenizer=tokenizer,
|
||||
backend_type=nllb_backend,
|
||||
device = device,
|
||||
nllb_size = nllb_size
|
||||
)
|
||||
for src_lang in src_langs:
|
||||
if src_lang != 'auto':
|
||||
translation_model.get_tokenizer(src_lang)
|
||||
return translation_model
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
self.len_processed_buffer = 0
|
||||
self.translation_remaining = Translation()
|
||||
self.validated = []
|
||||
self.translation_pending_validation = ''
|
||||
self.translation_model = translation_model
|
||||
self.input_languages = input_languages
|
||||
self.output_languages = output_languages
|
||||
|
||||
def compute_common_prefix(self, results):
|
||||
#we dont want want to prune the result for the moment.
|
||||
if not self.buffer:
|
||||
self.buffer = results
|
||||
else:
|
||||
for i in range(min(len(self.buffer), len(results))):
|
||||
if self.buffer[i] != results[i]:
|
||||
self.commited.extend(self.buffer[:i])
|
||||
self.buffer = results[i:]
|
||||
|
||||
def translate(self, input, input_lang, output_lang):
|
||||
if not input:
|
||||
return ""
|
||||
nllb_output_lang = get_nllb_code(output_lang)
|
||||
|
||||
tokenizer = self.translation_model.get_tokenizer(input_lang)
|
||||
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
|
||||
|
||||
if self.translation_model.backend_type == 'ctranslate2':
|
||||
source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||||
else:
|
||||
translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
|
||||
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
return result
|
||||
|
||||
def translate_tokens(self, tokens):
|
||||
if tokens:
|
||||
text = ' '.join([token.text for token in tokens])
|
||||
start = tokens[0].start
|
||||
end = tokens[-1].end
|
||||
if self.input_languages[0] == 'auto':
|
||||
input_lang = tokens[0].detected_language
|
||||
else:
|
||||
input_lang = self.input_languages[0]
|
||||
|
||||
translated_text = self.translate(text,
|
||||
input_lang,
|
||||
self.output_languages[0]
|
||||
)
|
||||
translation = Translation(
|
||||
text=translated_text,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
return translation
|
||||
return None
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
|
||||
self.buffer.extend(tokens)
|
||||
pass
|
||||
|
||||
def process(self):
|
||||
i = 0
|
||||
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
|
||||
return self.validated + [self.translation_remaining]
|
||||
while i < len(self.buffer):
|
||||
if self.buffer[i].is_punctuation():
|
||||
translation_sentence = self.translate_tokens(self.buffer[:i+1])
|
||||
self.validated.append(translation_sentence)
|
||||
self.buffer = self.buffer[i+1:]
|
||||
i = 0
|
||||
else:
|
||||
i+=1
|
||||
self.translation_remaining = self.translate_tokens(self.buffer)
|
||||
self.len_processed_buffer = len(self.buffer)
|
||||
return self.validated + [self.translation_remaining]
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
|
||||
self.buffer = []
|
||||
self.validated += [self.translation_remaining]
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_lang = 'fr'
|
||||
input_lang = "en"
|
||||
|
||||
|
||||
test_string = """
|
||||
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
|
||||
"""
|
||||
test = test_string.split(' ')
|
||||
step = len(test) // 3
|
||||
|
||||
shared_model = load_model([input_lang], nllb_backend='ctranslate2')
|
||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||
|
||||
beg_inference = time.time()
|
||||
for id in range(5):
|
||||
val = test[id*step : (id+1)*step]
|
||||
val_str = ' '.join(val)
|
||||
result = online_translation.translate(val_str)
|
||||
print(result)
|
||||
print('inference time:', time.time() - beg_inference)
|
||||
@@ -6,57 +6,46 @@ logger = logging.getLogger(__name__)
|
||||
def load_file(warmup_file=None, timeout=5):
|
||||
import os
|
||||
import tempfile
|
||||
import urllib.request
|
||||
import librosa
|
||||
|
||||
|
||||
if warmup_file == "":
|
||||
logger.info(f"Skipping warmup.")
|
||||
return None
|
||||
|
||||
# Download JFK sample if not already present
|
||||
if warmup_file is None:
|
||||
# Download JFK sample if not already present
|
||||
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||
temp_dir = tempfile.gettempdir()
|
||||
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
||||
|
||||
if not os.path.exists(warmup_file):
|
||||
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||
print(f"Downloading warmup file from {jfk_url}")
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import socket
|
||||
|
||||
original_timeout = socket.getdefaulttimeout()
|
||||
socket.setdefaulttimeout(timeout)
|
||||
|
||||
start_time = time.time()
|
||||
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||
try:
|
||||
urllib.request.urlretrieve(jfk_url, warmup_file)
|
||||
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
|
||||
except (urllib.error.URLError, socket.timeout) as e:
|
||||
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
|
||||
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||
with urllib.request.urlopen(jfk_url, timeout=timeout) as r, open(warmup_file, "wb") as f:
|
||||
f.write(r.read())
|
||||
except Exception as e:
|
||||
logger.warning(f"Warmup file download failed: {e}.")
|
||||
return None
|
||||
finally:
|
||||
socket.setdefaulttimeout(original_timeout)
|
||||
elif not warmup_file:
|
||||
return None
|
||||
|
||||
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
||||
|
||||
# Validate file and load
|
||||
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||
logger.warning(f"Warmup file {warmup_file} is invalid or missing.")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
audio, sr = librosa.load(warmup_file, sr=16000)
|
||||
audio, _ = librosa.load(warmup_file, sr=16000)
|
||||
return audio
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load audio file: {e}")
|
||||
logger.warning(f"Failed to load warmup file: {e}")
|
||||
return None
|
||||
return audio
|
||||
|
||||
def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||
"""
|
||||
Warmup the ASR model by transcribing a short audio file.
|
||||
"""
|
||||
audio = load_file(warmup_file=None, timeout=5)
|
||||
audio = load_file(warmup_file=warmup_file, timeout=timeout)
|
||||
if audio is None:
|
||||
logger.warning("Warmup file unavailable. Skipping ASR warmup.")
|
||||
return
|
||||
asr.transcribe(audio)
|
||||
logger.info("ASR model is warmed up")
|
||||
|
||||
def warmup_online(online, warmup_file=None, timeout=5):
|
||||
audio = load_file(warmup_file=None, timeout=5)
|
||||
online.warmup(audio)
|
||||
logger.warning("ASR is warmed up")
|
||||
logger.info("ASR model is warmed up.")
|
||||
@@ -72,12 +72,21 @@
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
html.is-extension
|
||||
{
|
||||
width: 350px;
|
||||
height: 500px;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
margin: 20px;
|
||||
margin: 0;
|
||||
text-align: center;
|
||||
background-color: var(--bg);
|
||||
color: var(--text);
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
/* Record button */
|
||||
@@ -168,9 +177,18 @@ body {
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 20px;
|
||||
margin-top: 15px;
|
||||
font-size: 16px;
|
||||
color: var(--text);
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.header-container {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background-color: var(--bg);
|
||||
z-index: 100;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
/* Settings */
|
||||
@@ -179,7 +197,14 @@ body {
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
margin-top: 20px;
|
||||
position: relative;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
@@ -189,6 +214,66 @@ body {
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.settings-toggle {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
border: 1px solid var(--button-border);
|
||||
cursor: pointer;
|
||||
display: none;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.settings-toggle:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle.active {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle img {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 10000px) {
|
||||
.settings-toggle {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: none;
|
||||
background: var(--bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 18px;
|
||||
padding: 12px;
|
||||
}
|
||||
|
||||
.settings.visible {
|
||||
display: flex;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
}
|
||||
}
|
||||
|
||||
.field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
@@ -297,9 +382,21 @@ label {
|
||||
border-radius: 999px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 20px;
|
||||
scrollbar-width: none;
|
||||
-ms-overflow-style: none;
|
||||
}
|
||||
|
||||
.transcript-container::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Transcript area */
|
||||
#linesTranscript {
|
||||
margin: 20px auto;
|
||||
margin: 0 auto;
|
||||
max-width: 700px;
|
||||
text-align: left;
|
||||
font-size: 16px;
|
||||
@@ -323,7 +420,7 @@ label {
|
||||
|
||||
.label_diarization {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
margin-left: 10px;
|
||||
display: inline-block;
|
||||
@@ -335,7 +432,7 @@ label {
|
||||
|
||||
.label_transcription {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 8px 8px 8px 8px;
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
@@ -345,9 +442,34 @@ label {
|
||||
color: var(--label-trans-text);
|
||||
}
|
||||
|
||||
.label_translation {
|
||||
background-color: var(--chip-bg);
|
||||
display: inline-flex;
|
||||
border-radius: 10px;
|
||||
padding: 4px 8px;
|
||||
margin-top: 4px;
|
||||
font-size: 14px;
|
||||
color: var(--text);
|
||||
align-items: flex-start;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.lag-diarization-value {
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.label_translation img {
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.label_translation img {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
}
|
||||
|
||||
#timeInfo {
|
||||
color: var(--muted);
|
||||
margin-left: 10px;
|
||||
margin-left: 0px;
|
||||
}
|
||||
|
||||
.textcontent {
|
||||
@@ -361,7 +483,6 @@ label {
|
||||
|
||||
.buffer_diarization {
|
||||
color: var(--label-dia-text);
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_transcription {
|
||||
@@ -406,12 +527,20 @@ label {
|
||||
}
|
||||
|
||||
/* for smaller screens */
|
||||
@media (max-width: 768px) {
|
||||
@media (max-width: 200px) {
|
||||
.header-container {
|
||||
padding: 15px;
|
||||
}
|
||||
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
@@ -430,11 +559,15 @@ label {
|
||||
.theme-selector-container {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
padding: 15px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
body {
|
||||
margin: 10px;
|
||||
.header-container {
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
@@ -457,4 +590,36 @@ label {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
padding: 10px;
|
||||
}
|
||||
}
|
||||
|
||||
.label_language {
|
||||
background-color: var(--chip-bg);
|
||||
margin-bottom: 0px;
|
||||
border-radius: 100px;
|
||||
padding: 2px 8px;
|
||||
margin-left: 10px;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
font-size: 14px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
|
||||
.speaker-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
margin-left: -5px;
|
||||
border-radius: 50%;
|
||||
font-size: 11px;
|
||||
line-height: 1;
|
||||
font-weight: 800;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
@@ -5,69 +5,75 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>WhisperLiveKit</title>
|
||||
<link rel="stylesheet" href="/web/live_transcription.css" />
|
||||
<link rel="stylesheet" href="live_transcription.css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="settings-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
<div class="header-container">
|
||||
<div class="settings-container">
|
||||
<div class="buttons-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
</div>
|
||||
<div class="recording-info">
|
||||
<div class="wave-container">
|
||||
<canvas id="waveCanvas"></canvas>
|
||||
</div>
|
||||
<div class="timer">00:00</div>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
|
||||
<img src="web/src/settings.svg" alt="Settings" />
|
||||
</button>
|
||||
</div>
|
||||
<div class="recording-info">
|
||||
<div class="wave-container">
|
||||
<canvas id="waveCanvas"></canvas>
|
||||
|
||||
<div class="settings">
|
||||
<div class="field">
|
||||
<label for="websocketInput">Websocket URL</label>
|
||||
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
|
||||
</div>
|
||||
<div class="timer">00:00</div>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<div class="settings">
|
||||
<div class="field">
|
||||
<label for="websocketInput">Websocket URL</label>
|
||||
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
|
||||
</div>
|
||||
<div class="field">
|
||||
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
|
||||
<select id="microphoneSelect">
|
||||
<option value="">Default Microphone</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
|
||||
<select id="microphoneSelect">
|
||||
<option value="">Default Microphone</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="theme-selector-container">
|
||||
<div class="segmented" role="radiogroup" aria-label="Theme selector">
|
||||
<input type="radio" id="theme-system" name="theme" value="system" />
|
||||
<label for="theme-system" title="System">
|
||||
<img src="/web/src/system_mode.svg" alt="" />
|
||||
<span>System</span>
|
||||
</label>
|
||||
|
||||
<div class="theme-selector-container">
|
||||
<div class="segmented" role="radiogroup" aria-label="Theme selector">
|
||||
<input type="radio" id="theme-system" name="theme" value="system" />
|
||||
<label for="theme-system" title="System">
|
||||
<img src="/web/src/system_mode.svg" alt="" />
|
||||
<span>System</span>
|
||||
</label>
|
||||
<input type="radio" id="theme-light" name="theme" value="light" />
|
||||
<label for="theme-light" title="Light">
|
||||
<img src="/web/src/light_mode.svg" alt="" />
|
||||
<span>Light</span>
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-light" name="theme" value="light" />
|
||||
<label for="theme-light" title="Light">
|
||||
<img src="/web/src/light_mode.svg" alt="" />
|
||||
<span>Light</span>
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-dark" name="theme" value="dark" />
|
||||
<label for="theme-dark" title="Dark">
|
||||
<img src="/web/src/dark_mode.svg" alt="" />
|
||||
<span>Dark</span>
|
||||
</label>
|
||||
<input type="radio" id="theme-dark" name="theme" value="dark" />
|
||||
<label for="theme-dark" title="Dark">
|
||||
<img src="/web/src/dark_mode.svg" alt="" />
|
||||
<span>Dark</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p id="status"></p>
|
||||
</div>
|
||||
|
||||
<div class="transcript-container">
|
||||
<div id="linesTranscript"></div>
|
||||
</div>
|
||||
|
||||
|
||||
<p id="status"></p>
|
||||
|
||||
<div id="linesTranscript"></div>
|
||||
|
||||
<script src="/web/live_transcription.js"></script>
|
||||
<script src="live_transcription.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
|
||||
const isExtension = typeof chrome !== 'undefined' && chrome.runtime && chrome.runtime.getURL;
|
||||
if (isExtension) {
|
||||
document.documentElement.classList.add('is-extension');
|
||||
}
|
||||
const isWebContext = !isExtension;
|
||||
|
||||
let isRecording = false;
|
||||
let websocket = null;
|
||||
@@ -12,6 +16,8 @@ let timerInterval = null;
|
||||
let audioContext = null;
|
||||
let analyser = null;
|
||||
let microphone = null;
|
||||
let workletNode = null;
|
||||
let recorderWorker = null;
|
||||
let waveCanvas = document.getElementById("waveCanvas");
|
||||
let waveCtx = waveCanvas.getContext("2d");
|
||||
let animationFrame = null;
|
||||
@@ -20,6 +26,11 @@ let lastReceivedData = null;
|
||||
let lastSignature = null;
|
||||
let availableMicrophones = [];
|
||||
let selectedMicrophoneId = null;
|
||||
let serverUseAudioWorklet = null;
|
||||
let configReadyResolve;
|
||||
const configReady = new Promise((r) => (configReadyResolve = r));
|
||||
let outputAudioContext = null;
|
||||
let audioSource = null;
|
||||
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
@@ -35,6 +46,26 @@ const timerElement = document.querySelector(".timer");
|
||||
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||
|
||||
const settingsToggle = document.getElementById("settingsToggle");
|
||||
const settingsDiv = document.querySelector(".settings");
|
||||
|
||||
// if (isExtension) {
|
||||
// chrome.runtime.onInstalled.addListener((details) => {
|
||||
// if (details.reason.search(/install/g) === -1) {
|
||||
// return;
|
||||
// }
|
||||
// chrome.tabs.create({
|
||||
// url: chrome.runtime.getURL("welcome.html"),
|
||||
// active: true
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
|
||||
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
|
||||
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
|
||||
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
|
||||
|
||||
function getWaveStroke() {
|
||||
const styles = getComputedStyle(document.documentElement);
|
||||
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||
@@ -146,10 +177,16 @@ function fmt1(x) {
|
||||
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||
}
|
||||
|
||||
// Default WebSocket URL computation
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = window.location.port;
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
let host, port, protocol;
|
||||
port = 8000;
|
||||
if (isExtension) {
|
||||
host = "localhost";
|
||||
protocol = "ws";
|
||||
} else {
|
||||
host = window.location.hostname || "localhost";
|
||||
port = window.location.port;
|
||||
protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
}
|
||||
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
|
||||
|
||||
// Populate default caption and input
|
||||
@@ -226,6 +263,14 @@ function setupWebSocket() {
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.type === "config") {
|
||||
serverUseAudioWorklet = !!data.useAudioWorklet;
|
||||
statusText.textContent = serverUseAudioWorklet
|
||||
? "Connected. Using AudioWorklet (PCM)."
|
||||
: "Connected. Using MediaRecorder (WebM).";
|
||||
if (configReadyResolve) configReadyResolve();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
@@ -293,7 +338,7 @@ function renderLinesWithBuffer(
|
||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||
const signature = JSON.stringify({
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, beg: it.beg, end: it.end })),
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
||||
buffer_transcription: buffer_transcription || "",
|
||||
buffer_diarization: buffer_diarization || "",
|
||||
status: current_status,
|
||||
@@ -316,19 +361,24 @@ function renderLinesWithBuffer(
|
||||
const linesHtml = (lines || [])
|
||||
.map((item, idx) => {
|
||||
let timeInfo = "";
|
||||
if (item.beg !== undefined && item.end !== undefined) {
|
||||
timeInfo = ` ${item.beg} - ${item.end}`;
|
||||
if (item.start !== undefined && item.end !== undefined) {
|
||||
timeInfo = ` ${item.start} - ${item.end}`;
|
||||
}
|
||||
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker !== 0) {
|
||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
|
||||
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
|
||||
if (item.detected_language) {
|
||||
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
@@ -365,6 +415,16 @@ function renderLinesWithBuffer(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (item.translation) {
|
||||
currentLineText += `
|
||||
<div>
|
||||
<div class="label_translation">
|
||||
${translationIcon}
|
||||
<span>${item.translation}</span>
|
||||
</div>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
@@ -373,7 +433,10 @@ function renderLinesWithBuffer(
|
||||
.join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
|
||||
const transcriptContainer = document.querySelector('.transcript-container');
|
||||
if (transcriptContainer) {
|
||||
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
|
||||
}
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
@@ -435,11 +498,44 @@ async function startRecording() {
|
||||
console.log("Error acquiring wake lock.");
|
||||
}
|
||||
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
|
||||
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
let stream;
|
||||
|
||||
// chromium extension. in the future, both chrome page audio and mic will be used
|
||||
if (isExtension) {
|
||||
try {
|
||||
stream = await new Promise((resolve, reject) => {
|
||||
chrome.tabCapture.capture({audio: true}, (s) => {
|
||||
if (s) {
|
||||
resolve(s);
|
||||
} else {
|
||||
reject(new Error('Tab capture failed or not available'));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
try {
|
||||
outputAudioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
audioSource = outputAudioContext.createMediaStreamSource(stream);
|
||||
audioSource.connect(outputAudioContext.destination);
|
||||
} catch (audioError) {
|
||||
console.warn('could not preserve system audio:', audioError);
|
||||
}
|
||||
|
||||
statusText.textContent = "Using tab audio capture.";
|
||||
} catch (tabError) {
|
||||
console.log('Tab capture not available, falling back to microphone', tabError);
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
statusText.textContent = "Using microphone audio.";
|
||||
}
|
||||
} else if (isWebContext) {
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
}
|
||||
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
@@ -447,13 +543,54 @@ async function startRecording() {
|
||||
microphone = audioContext.createMediaStreamSource(stream);
|
||||
microphone.connect(analyser);
|
||||
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send(e.data);
|
||||
if (serverUseAudioWorklet) {
|
||||
if (!audioContext.audioWorklet) {
|
||||
throw new Error("AudioWorklet is not supported in this browser");
|
||||
}
|
||||
};
|
||||
recorder.start(chunkDuration);
|
||||
await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");
|
||||
workletNode = new AudioWorkletNode(audioContext, "pcm-forwarder", { numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1 });
|
||||
microphone.connect(workletNode);
|
||||
|
||||
recorderWorker = new Worker("/web/recorder_worker.js");
|
||||
recorderWorker.postMessage({
|
||||
command: "init",
|
||||
config: {
|
||||
sampleRate: audioContext.sampleRate,
|
||||
},
|
||||
});
|
||||
|
||||
recorderWorker.onmessage = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send(e.data.buffer);
|
||||
}
|
||||
};
|
||||
|
||||
workletNode.port.onmessage = (e) => {
|
||||
const data = e.data;
|
||||
const ab = data instanceof ArrayBuffer ? data : data.buffer;
|
||||
recorderWorker.postMessage(
|
||||
{
|
||||
command: "record",
|
||||
buffer: ab,
|
||||
},
|
||||
[ab]
|
||||
);
|
||||
};
|
||||
} else {
|
||||
try {
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
} catch (e) {
|
||||
recorder = new MediaRecorder(stream);
|
||||
}
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
if (e.data && e.data.size > 0) {
|
||||
websocket.send(e.data);
|
||||
}
|
||||
}
|
||||
};
|
||||
recorder.start(chunkDuration);
|
||||
}
|
||||
|
||||
startTime = Date.now();
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
@@ -492,10 +629,28 @@ async function stopRecording() {
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
recorder.stop();
|
||||
try {
|
||||
recorder.stop();
|
||||
} catch (e) {
|
||||
}
|
||||
recorder = null;
|
||||
}
|
||||
|
||||
if (recorderWorker) {
|
||||
recorderWorker.terminate();
|
||||
recorderWorker = null;
|
||||
}
|
||||
|
||||
if (workletNode) {
|
||||
try {
|
||||
workletNode.port.onmessage = null;
|
||||
} catch (e) {}
|
||||
try {
|
||||
workletNode.disconnect();
|
||||
} catch (e) {}
|
||||
workletNode = null;
|
||||
}
|
||||
|
||||
if (microphone) {
|
||||
microphone.disconnect();
|
||||
microphone = null;
|
||||
@@ -514,6 +669,16 @@ async function stopRecording() {
|
||||
audioContext = null;
|
||||
}
|
||||
|
||||
if (audioSource) {
|
||||
audioSource.disconnect();
|
||||
audioSource = null;
|
||||
}
|
||||
|
||||
if (outputAudioContext && outputAudioContext.state !== "closed") {
|
||||
outputAudioContext.close()
|
||||
outputAudioContext = null;
|
||||
}
|
||||
|
||||
if (animationFrame) {
|
||||
cancelAnimationFrame(animationFrame);
|
||||
animationFrame = null;
|
||||
@@ -539,9 +704,11 @@ async function toggleRecording() {
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await configReady;
|
||||
await startRecording();
|
||||
} else {
|
||||
await setupWebSocket();
|
||||
await configReady;
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
@@ -563,7 +730,7 @@ function updateUI() {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "Recording...";
|
||||
statusText.textContent = "";
|
||||
} else {
|
||||
if (
|
||||
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
@@ -597,3 +764,40 @@ navigator.mediaDevices.addEventListener('devicechange', async () => {
|
||||
console.log("Error re-enumerating microphones:", error);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
settingsToggle.addEventListener("click", () => {
|
||||
settingsDiv.classList.toggle("visible");
|
||||
settingsToggle.classList.toggle("active");
|
||||
});
|
||||
|
||||
if (isExtension) {
|
||||
async function checkAndRequestPermissions() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
const permissionDisplay = document.getElementById("audioPermission");
|
||||
if (permissionDisplay) {
|
||||
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
}
|
||||
|
||||
// if (micPermission.state !== "granted") {
|
||||
// chrome.tabs.create({ url: "welcome.html" });
|
||||
// }
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
if (permissionDisplay) {
|
||||
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
}
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void checkAndRequestPermissions();
|
||||
}
|
||||
|
||||
16
whisperlivekit/web/pcm_worklet.js
Normal file
@@ -0,0 +1,16 @@
|
||||
class PCMForwarder extends AudioWorkletProcessor {
|
||||
process(inputs) {
|
||||
const input = inputs[0];
|
||||
if (input && input[0] && input[0].length) {
|
||||
// Forward mono channel (0). If multi-channel, downmixing can be added here.
|
||||
const channelData = input[0];
|
||||
const copy = new Float32Array(channelData.length);
|
||||
copy.set(channelData);
|
||||
this.port.postMessage(copy, [copy.buffer]);
|
||||
}
|
||||
// Keep processor alive
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
registerProcessor('pcm-forwarder', PCMForwarder);
|
||||
58
whisperlivekit/web/recorder_worker.js
Normal file
@@ -0,0 +1,58 @@
|
||||
let sampleRate = 48000;
|
||||
let targetSampleRate = 16000;
|
||||
|
||||
self.onmessage = function (e) {
|
||||
switch (e.data.command) {
|
||||
case 'init':
|
||||
init(e.data.config);
|
||||
break;
|
||||
case 'record':
|
||||
record(e.data.buffer);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
function init(config) {
|
||||
sampleRate = config.sampleRate;
|
||||
targetSampleRate = config.targetSampleRate || 16000;
|
||||
}
|
||||
|
||||
function record(inputBuffer) {
|
||||
const buffer = new Float32Array(inputBuffer);
|
||||
const resampledBuffer = resample(buffer, sampleRate, targetSampleRate);
|
||||
const pcmBuffer = toPCM(resampledBuffer);
|
||||
self.postMessage({ buffer: pcmBuffer }, [pcmBuffer]);
|
||||
}
|
||||
|
||||
function resample(buffer, from, to) {
|
||||
if (from === to) {
|
||||
return buffer;
|
||||
}
|
||||
const ratio = from / to;
|
||||
const newLength = Math.round(buffer.length / ratio);
|
||||
const result = new Float32Array(newLength);
|
||||
let offsetResult = 0;
|
||||
let offsetBuffer = 0;
|
||||
while (offsetResult < result.length) {
|
||||
const nextOffsetBuffer = Math.round((offsetResult + 1) * ratio);
|
||||
let accum = 0, count = 0;
|
||||
for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
|
||||
accum += buffer[i];
|
||||
count++;
|
||||
}
|
||||
result[offsetResult] = accum / count;
|
||||
offsetResult++;
|
||||
offsetBuffer = nextOffsetBuffer;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
function toPCM(input) {
|
||||
const buffer = new ArrayBuffer(input.length * 2);
|
||||
const view = new DataView(buffer);
|
||||
for (let i = 0; i < input.length; i++) {
|
||||
const s = Math.max(-1, Math.min(1, input[i]));
|
||||
view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
1
whisperlivekit/web/src/language.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>
|
||||
|
After Width: | Height: | Size: 976 B |
1
whisperlivekit/web/src/settings.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>
|
||||
|
After Width: | Height: | Size: 982 B |
1
whisperlivekit/web/src/silence.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>
|
||||
|
After Width: | Height: | Size: 984 B |
1
whisperlivekit/web/src/speaker.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>
|
||||
|
After Width: | Height: | Size: 592 B |
1
whisperlivekit/web/src/translate.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>
|
||||
|
After Width: | Height: | Size: 650 B |
@@ -23,6 +23,24 @@ def get_inline_ui_html():
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
|
||||
js_content = f.read()
|
||||
|
||||
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
|
||||
worklet_code = f.read()
|
||||
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
|
||||
worker_code = f.read()
|
||||
|
||||
js_content = js_content.replace(
|
||||
'await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");',
|
||||
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
|
||||
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
|
||||
'await audioContext.audioWorklet.addModule(workletUrl);'
|
||||
)
|
||||
js_content = js_content.replace(
|
||||
'recorderWorker = new Worker("/web/recorder_worker.js");',
|
||||
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
|
||||
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
|
||||
'recorderWorker = new Worker(workerUrl);'
|
||||
)
|
||||
|
||||
# SVG files
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
|
||||
system_svg = f.read()
|
||||
@@ -33,15 +51,18 @@ def get_inline_ui_html():
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
|
||||
dark_svg = f.read()
|
||||
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
|
||||
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'settings.svg').open('r', encoding='utf-8') as f:
|
||||
settings = f.read()
|
||||
settings_uri = f"data:image/svg+xml;base64,{base64.b64encode(settings.encode('utf-8')).decode('utf-8')}"
|
||||
|
||||
# Replace external references
|
||||
html_content = html_content.replace(
|
||||
'<link rel="stylesheet" href="/web/live_transcription.css" />',
|
||||
'<link rel="stylesheet" href="live_transcription.css" />',
|
||||
f'<style>\n{css_content}\n</style>'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<script src="/web/live_transcription.js"></script>',
|
||||
'<script src="live_transcription.js"></script>',
|
||||
f'<script>\n{js_content}\n</script>'
|
||||
)
|
||||
|
||||
@@ -61,6 +82,11 @@ def get_inline_ui_html():
|
||||
f'<img src="{dark_data_uri}" alt="" />'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<img src="web/src/settings.svg" alt="Settings" />',
|
||||
f'<img src="{settings_uri}" alt="" />'
|
||||
)
|
||||
|
||||
return html_content
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -11,14 +11,14 @@ class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
if lan == "auto":
|
||||
self.original_language = None
|
||||
else:
|
||||
self.original_language = lan
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
self.model = self.load_model(model_size, cache_dir, model_dir)
|
||||
|
||||
def with_offset(self, offset: float) -> ASRToken:
|
||||
# This method is kept for compatibility (typically you will use ASRToken.with_offset)
|
||||
@@ -27,7 +27,7 @@ class ASRBase:
|
||||
def __repr__(self):
|
||||
return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
|
||||
|
||||
def load_model(self, modelsize, cache_dir, model_dir):
|
||||
def load_model(self, model_size, cache_dir, model_dir):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
@@ -41,7 +41,7 @@ class WhisperTimestampedASR(ASRBase):
|
||||
"""Uses whisper_timestamped as the backend."""
|
||||
sep = " "
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
from whisper_timestamped import transcribe_timestamped
|
||||
@@ -49,7 +49,7 @@ class WhisperTimestampedASR(ASRBase):
|
||||
self.transcribe_timestamped = transcribe_timestamped
|
||||
if model_dir is not None:
|
||||
logger.debug("ignoring model_dir, not implemented")
|
||||
return whisper.load_model(modelsize, download_root=cache_dir)
|
||||
return whisper.load_model(model_size, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
result = self.transcribe_timestamped(
|
||||
@@ -88,17 +88,17 @@ class FasterWhisperASR(ASRBase):
|
||||
"""Uses faster-whisper as the backend."""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
||||
f"modelsize and cache_dir parameters are not used.")
|
||||
f"model_size and cache_dir parameters are not used.")
|
||||
model_size_or_path = model_dir
|
||||
elif modelsize is not None:
|
||||
model_size_or_path = modelsize
|
||||
elif model_size is not None:
|
||||
model_size_or_path = model_size
|
||||
else:
|
||||
raise ValueError("Either modelsize or model_dir must be set")
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
device = "auto" # Allow CTranslate2 to decide available device
|
||||
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
|
||||
|
||||
@@ -149,18 +149,18 @@ class MLXWhisper(ASRBase):
|
||||
"""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||
import mlx.core as mx
|
||||
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.")
|
||||
model_size_or_path = model_dir
|
||||
elif modelsize is not None:
|
||||
model_size_or_path = self.translate_model_name(modelsize)
|
||||
logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
elif model_size is not None:
|
||||
model_size_or_path = self.translate_model_name(model_size)
|
||||
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
else:
|
||||
raise ValueError("Either modelsize or model_dir must be set")
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
|
||||
self.model_size_or_path = model_size_or_path
|
||||
dtype = mx.float16
|
||||
|
||||
@@ -106,9 +106,6 @@ class OnlineASRProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
tokenize_method: Optional[callable] = None,
|
||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
||||
confidence_validation = False,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
"""
|
||||
@@ -119,13 +116,14 @@ class OnlineASRProcessor:
|
||||
buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
|
||||
"""
|
||||
self.asr = asr
|
||||
self.tokenize = tokenize_method
|
||||
self.tokenize = asr.tokenizer
|
||||
self.logfile = logfile
|
||||
self.confidence_validation = confidence_validation
|
||||
self.confidence_validation = asr.confidence_validation
|
||||
self.global_time_offset = 0.0
|
||||
self.init()
|
||||
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
self.buffer_trimming_way = asr.buffer_trimming
|
||||
self.buffer_trimming_sec = asr.buffer_trimming_sec
|
||||
|
||||
if self.buffer_trimming_way not in ["sentence", "segment"]:
|
||||
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,11 +64,23 @@ def create_tokenizer(lan):
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def backend_factory(args):
|
||||
backend = args.backend
|
||||
def backend_factory(
|
||||
backend,
|
||||
lan,
|
||||
model_size,
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
task,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
confidence_validation,
|
||||
warmup_file=None,
|
||||
min_chunk_size=None,
|
||||
):
|
||||
backend = backend
|
||||
if backend == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=args.lan)
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
if backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
@@ -77,34 +90,33 @@ def backend_factory(args):
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
# Only for FasterWhisperASR and WhisperTimestampedASR
|
||||
size = args.model
|
||||
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {size} model for language {args.lan}...")
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan}...")
|
||||
asr = asr_cls(
|
||||
modelsize=size,
|
||||
lan=args.lan,
|
||||
cache_dir=getattr(args, 'model_cache_dir', None),
|
||||
model_dir=getattr(args, 'model_dir', None),
|
||||
model_size=model_size,
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_dir,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
# Apply common configurations
|
||||
if getattr(args, "vad", False): # Checks if VAD argument is present and True
|
||||
logger.info("Setting VAD filter")
|
||||
asr.use_vad()
|
||||
|
||||
language = args.lan
|
||||
if args.task == "translate":
|
||||
if backend != "simulstreaming":
|
||||
asr.set_translate_task()
|
||||
if task == "translate":
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = language # Whisper transcribes in this language
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if args.buffer_trimming == "sentence":
|
||||
if buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
return asr, tokenizer
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
asr.buffer_trimming_sec = buffer_trimming_sec
|
||||
return asr
|
||||