Compare commits
18 Commits
ScriptProc
...
0.2.11
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d24c110d55 | ||
|
|
4dd5d8bf8a | ||
|
|
93f002cafb | ||
|
|
c5e30c2c07 | ||
|
|
1c2afb8bd2 | ||
|
|
674b20d3af | ||
|
|
a5503308c5 | ||
|
|
e61afdefa3 | ||
|
|
426d70a790 | ||
|
|
b03a212fbf | ||
|
|
1833e7c921 | ||
|
|
777ec63a71 | ||
|
|
0a6e5ae9c1 | ||
|
|
ee448a37e9 | ||
|
|
9c051052b0 | ||
|
|
65025cc448 | ||
|
|
bbba1d9bb7 | ||
|
|
99dc96c644 |
23
DEV_NOTES.md
@@ -18,8 +18,29 @@ Decoder weights: 59110771 bytes
|
|||||||
Encoder weights: 15268874 bytes
|
Encoder weights: 15268874 bytes
|
||||||
|
|
||||||
|
|
||||||
|
# 2. Translation: Faster model for each system
|
||||||
|
|
||||||
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
## Benchmark Results
|
||||||
|
|
||||||
|
Testing on MacBook M3 with NLLB-200-distilled-600M model:
|
||||||
|
|
||||||
|
### Standard Transformers vs CTranslate2
|
||||||
|
|
||||||
|
| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|
||||||
|
|-----------|-------------------------|---------------------------|---------|
|
||||||
|
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
|
||||||
|
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
|
||||||
|
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
|
||||||
|
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
|
||||||
|
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |
|
||||||
|
|
||||||
|
**Results:**
|
||||||
|
- Total Standard time: 4.1068s
|
||||||
|
- Total CTranslate2 time: 8.5476s
|
||||||
|
- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation.
|
||||||
|
|
||||||
|
|
||||||
|
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
|
||||||
|
|
||||||
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
|
||||||
|
|
||||||
|
|||||||
72
README.md
@@ -18,9 +18,9 @@ Real-time speech transcription directly to your browser, with a ready-to-use bac
|
|||||||
|
|
||||||
#### Powered by Leading Research:
|
#### Powered by Leading Research:
|
||||||
|
|
||||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription with AlignAtt policy
|
- [SimulStreaming](https://github.com/ufalSimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||||
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
|
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
|
||||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription with LocalAgreement policy
|
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
||||||
@@ -42,15 +42,6 @@ pip install whisperlivekit
|
|||||||
```
|
```
|
||||||
> You can also clone the repo and `pip install -e .` for the latest version.
|
> You can also clone the repo and `pip install -e .` for the latest version.
|
||||||
|
|
||||||
|
|
||||||
> **FFmpeg is required** and must be installed before using WhisperLiveKit
|
|
||||||
>
|
|
||||||
> | OS | How to install |
|
|
||||||
> |-----------|-------------|
|
|
||||||
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
|
||||||
> | MacOS | `brew install ffmpeg` |
|
|
||||||
> | Windows | Download .exe from https://ffmpeg.org/download.html and add to PATH |
|
|
||||||
|
|
||||||
#### Quick Start
|
#### Quick Start
|
||||||
1. **Start the transcription server:**
|
1. **Start the transcription server:**
|
||||||
```bash
|
```bash
|
||||||
@@ -86,11 +77,11 @@ See **Parameters & Configuration** below on how to use them.
|
|||||||
**Command-line Interface**: Start the transcription server with various options:
|
**Command-line Interface**: Start the transcription server with various options:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Use better model than default (small)
|
# Large model and translate from french to danish
|
||||||
whisperlivekit-server --model large-v3
|
whisperlivekit-server --model large-v3 --language fr --target-language da
|
||||||
|
|
||||||
# Advanced configuration with diarization and language
|
# Diarization and server listening on */80
|
||||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -137,26 +128,15 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
|
|
||||||
## Parameters & Configuration
|
## Parameters & Configuration
|
||||||
|
|
||||||
An important list of parameters can be changed. But what *should* you change?
|
|
||||||
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
|
||||||
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
|
|
||||||
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
|
|
||||||
- `--warmup-file`, if you have one
|
|
||||||
- `--task translate`, to translate in english
|
|
||||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
|
|
||||||
- `--diarization`, if you want to use it.
|
|
||||||
- [BETA] `--target-language`, to translate using NLLB. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly.
|
|
||||||
|
|
||||||
### Full list of parameters :
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--model` | Whisper model size. | `small` |
|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md) | `small` |
|
||||||
| `--language` | Source language code or `auto` | `auto` |
|
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||||
| `--task` | Set to `translate` to translate to english | `transcribe` |
|
| `--target-language` | If sets, activates translation using NLLB. Ex: `fr`. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` |
|
||||||
| `--target-language` | [BETA] Translation language target. Ex: `fr` | `None` |
|
| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` |
|
||||||
| `--backend` | Processing backend | `simulstreaming` |
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
|
||||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||||
@@ -164,8 +144,19 @@ An important list of parameters can be changed. But what *should* you change?
|
|||||||
| `--port` | Server port | `8000` |
|
| `--port` | Server port | `8000` |
|
||||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||||
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. | `False` |
|
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
||||||
|
|
||||||
|
| Translation options | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
|
||||||
|
| `--nllb-size` | `600M` or `1.3B` | `600M` |
|
||||||
|
|
||||||
|
| Diarization options | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||||
|
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
|
||||||
|
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||||
|
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||||
|
|
||||||
| SimulStreaming backend options | Description | Default |
|
| SimulStreaming backend options | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
@@ -184,25 +175,16 @@ An important list of parameters can be changed. But what *should* you change?
|
|||||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
| WhisperStreaming backend options | Description | Default |
|
| WhisperStreaming backend options | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||||
|
|
||||||
| Diarization options | Description | Default |
|
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--diarization` | Enable speaker identification | `False` |
|
|
||||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
|
||||||
| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
|
|
||||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
|
||||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
|
||||||
|
|
||||||
|
|
||||||
> For diarization using Diart, you need access to pyannote.audio models:
|
|
||||||
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
> For diarization using Diart, you need to accept user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, login to HuggingFace: `huggingface-cli login`
|
||||||
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
|
||||||
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
|
|
||||||
>4. Login with HuggingFace: `huggingface-cli login`
|
|
||||||
|
|
||||||
### 🚀 Deployment Guide
|
### 🚀 Deployment Guide
|
||||||
|
|
||||||
|
|||||||
BIN
architecture.png
|
Before Width: | Height: | Size: 368 KiB After Width: | Height: | Size: 390 KiB |
@@ -1,4 +1,4 @@
|
|||||||
# Available model sizes:
|
# Available Whisper model sizes:
|
||||||
|
|
||||||
- tiny.en (english only)
|
- tiny.en (english only)
|
||||||
- tiny
|
- tiny
|
||||||
@@ -71,3 +71,39 @@
|
|||||||
3. Good hardware and want best quality? → `large-v3`
|
3. Good hardware and want best quality? → `large-v3`
|
||||||
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
|
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
|
||||||
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
||||||
|
|
||||||
|
|
||||||
|
_______________________
|
||||||
|
|
||||||
|
# Translation Models and Backend
|
||||||
|
|
||||||
|
**Language Support**: ~200 languages
|
||||||
|
|
||||||
|
## Distilled Model Sizes Available
|
||||||
|
|
||||||
|
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|
||||||
|
|-------|------|------------|-------------|-------------|---------|
|
||||||
|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
|
||||||
|
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
|
||||||
|
|
||||||
|
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
|
||||||
|
|
||||||
|
## Backend Performance
|
||||||
|
|
||||||
|
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|
||||||
|
|---------|---------------|--------------|--------------|
|
||||||
|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
|
||||||
|
| Transformers | Baseline | High | None |
|
||||||
|
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
|
||||||
|
|
||||||
|
**Metrics**:
|
||||||
|
- CTranslate2: 50-100+ tokens/sec
|
||||||
|
- Transformers: 10-30 tokens/sec
|
||||||
|
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
|
||||||
|
|
||||||
|
## Quick Decision Matrix
|
||||||
|
|
||||||
|
**Choose 600M**: Limited resources, close to 0 lag
|
||||||
|
**Choose 1.3B**: Quality matters
|
||||||
|
**Choose Transformers**: On Apple Silicon
|
||||||
|
|
||||||
|
|||||||
BIN
demo.png
|
Before Width: | Height: | Size: 449 KiB After Width: | Height: | Size: 985 KiB |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "whisperlivekit"
|
name = "whisperlivekit"
|
||||||
version = "0.2.9"
|
version = "0.2.11"
|
||||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
@@ -4,10 +4,11 @@ from time import time, sleep
|
|||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State
|
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
|
||||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||||
from whisperlivekit.results_formater import format_output
|
from whisperlivekit.results_formater import format_output
|
||||||
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
# Set up logging once
|
# Set up logging once
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -48,7 +49,7 @@ class AudioProcessor:
|
|||||||
self.bytes_per_sample = 2
|
self.bytes_per_sample = 2
|
||||||
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||||
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||||
self.is_pcm_input = True
|
self.is_pcm_input = self.args.pcm_input
|
||||||
self.debug = False
|
self.debug = False
|
||||||
|
|
||||||
# State management
|
# State management
|
||||||
@@ -57,14 +58,17 @@ class AudioProcessor:
|
|||||||
self.silence_duration = 0.0
|
self.silence_duration = 0.0
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.translated_segments = []
|
self.translated_segments = []
|
||||||
self.buffer_transcription = ""
|
self.buffer_transcription = Transcript()
|
||||||
self.buffer_diarization = ""
|
|
||||||
self.end_buffer = 0
|
self.end_buffer = 0
|
||||||
self.end_attributed_speaker = 0
|
self.end_attributed_speaker = 0
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
||||||
self.sep = " " # Default separator
|
self.sep = " " # Default separator
|
||||||
self.last_response_content = FrontData()
|
self.last_response_content = FrontData()
|
||||||
|
self.last_detected_speaker = None
|
||||||
|
self.speaker_languages = {}
|
||||||
|
self.cumulative_pcm_len = 0
|
||||||
|
self.diarization_before_transcription = False
|
||||||
|
|
||||||
# Models and processing
|
# Models and processing
|
||||||
self.asr = models.asr
|
self.asr = models.asr
|
||||||
@@ -75,6 +79,20 @@ class AudioProcessor:
|
|||||||
else:
|
else:
|
||||||
self.vac = None
|
self.vac = None
|
||||||
|
|
||||||
|
self.ffmpeg_manager = None
|
||||||
|
self.ffmpeg_reader_task = None
|
||||||
|
self._ffmpeg_error = None
|
||||||
|
|
||||||
|
if not self.is_pcm_input:
|
||||||
|
self.ffmpeg_manager = FFmpegManager(
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
channels=self.channels
|
||||||
|
)
|
||||||
|
async def handle_ffmpeg_error(error_type: str):
|
||||||
|
logger.error(f"FFmpeg error: {error_type}")
|
||||||
|
self._ffmpeg_error = error_type
|
||||||
|
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||||
|
|
||||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||||
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
||||||
@@ -84,33 +102,20 @@ class AudioProcessor:
|
|||||||
self.diarization_task = None
|
self.diarization_task = None
|
||||||
self.watchdog_task = None
|
self.watchdog_task = None
|
||||||
self.all_tasks_for_cleanup = []
|
self.all_tasks_for_cleanup = []
|
||||||
|
self.online_translation = None
|
||||||
|
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||||
self.sep = self.online.asr.sep
|
self.sep = self.online.asr.sep
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||||
if self.args.target_language:
|
if models.translation_model:
|
||||||
self.online_translation = online_translation_factory(self.args, models.translation_model)
|
self.online_translation = online_translation_factory(self.args, models.translation_model)
|
||||||
|
|
||||||
def convert_pcm_to_float(self, pcm_buffer):
|
def convert_pcm_to_float(self, pcm_buffer):
|
||||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
async def update_transcription(self, new_tokens, buffer, end_buffer):
|
|
||||||
"""Thread-safe update of transcription with new data."""
|
|
||||||
async with self.lock:
|
|
||||||
self.tokens.extend(new_tokens)
|
|
||||||
self.buffer_transcription = buffer
|
|
||||||
self.end_buffer = end_buffer
|
|
||||||
|
|
||||||
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
|
||||||
"""Thread-safe update of diarization with new data."""
|
|
||||||
async with self.lock:
|
|
||||||
self.end_attributed_speaker = end_attributed_speaker
|
|
||||||
if buffer_diarization:
|
|
||||||
self.buffer_diarization = buffer_diarization
|
|
||||||
|
|
||||||
async def add_dummy_token(self):
|
async def add_dummy_token(self):
|
||||||
"""Placeholder token when no transcription is available."""
|
"""Placeholder token when no transcription is available."""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
@@ -139,7 +144,6 @@ class AudioProcessor:
|
|||||||
tokens=self.tokens.copy(),
|
tokens=self.tokens.copy(),
|
||||||
translated_segments=self.translated_segments.copy(),
|
translated_segments=self.translated_segments.copy(),
|
||||||
buffer_transcription=self.buffer_transcription,
|
buffer_transcription=self.buffer_transcription,
|
||||||
buffer_diarization=self.buffer_diarization,
|
|
||||||
end_buffer=self.end_buffer,
|
end_buffer=self.end_buffer,
|
||||||
end_attributed_speaker=self.end_attributed_speaker,
|
end_attributed_speaker=self.end_attributed_speaker,
|
||||||
remaining_time_transcription=remaining_transcription,
|
remaining_time_transcription=remaining_transcription,
|
||||||
@@ -151,10 +155,60 @@ class AudioProcessor:
|
|||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.translated_segments = []
|
self.translated_segments = []
|
||||||
self.buffer_transcription = self.buffer_diarization = ""
|
self.buffer_transcription = Transcript()
|
||||||
self.end_buffer = self.end_attributed_speaker = 0
|
self.end_buffer = self.end_attributed_speaker = 0
|
||||||
self.beg_loop = time()
|
self.beg_loop = time()
|
||||||
|
|
||||||
|
async def ffmpeg_stdout_reader(self):
|
||||||
|
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||||
|
beg = time()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if self.is_stopping:
|
||||||
|
logger.info("Stopping ffmpeg_stdout_reader due to stopping flag.")
|
||||||
|
break
|
||||||
|
|
||||||
|
state = await self.ffmpeg_manager.get_state() if self.ffmpeg_manager else FFmpegState.STOPPED
|
||||||
|
if state == FFmpegState.FAILED:
|
||||||
|
logger.error("FFmpeg is in FAILED state, cannot read data")
|
||||||
|
break
|
||||||
|
elif state == FFmpegState.STOPPED:
|
||||||
|
logger.info("FFmpeg is stopped")
|
||||||
|
break
|
||||||
|
elif state != FFmpegState.RUNNING:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_time = time()
|
||||||
|
elapsed_time = max(0.0, current_time - beg)
|
||||||
|
buffer_size = max(int(32000 * elapsed_time), 4096) # dynamic read
|
||||||
|
beg = current_time
|
||||||
|
|
||||||
|
chunk = await self.ffmpeg_manager.read_data(buffer_size)
|
||||||
|
if not chunk:
|
||||||
|
# No data currently available
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.pcm_buffer.extend(chunk)
|
||||||
|
await self.handle_pcm_data()
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("ffmpeg_stdout_reader cancelled.")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||||
|
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
|
||||||
|
if not self.diarization_before_transcription and self.transcription_queue:
|
||||||
|
await self.transcription_queue.put(SENTINEL)
|
||||||
|
if self.args.diarization and self.diarization_queue:
|
||||||
|
await self.diarization_queue.put(SENTINEL)
|
||||||
|
if self.online_translation:
|
||||||
|
await self.translation_queue.put(SENTINEL)
|
||||||
|
|
||||||
async def transcription_processor(self):
|
async def transcription_processor(self):
|
||||||
"""Process audio chunks for transcription."""
|
"""Process audio chunks for transcription."""
|
||||||
cumulative_pcm_duration_stream_time = 0.0
|
cumulative_pcm_duration_stream_time = 0.0
|
||||||
@@ -167,11 +221,6 @@ class AudioProcessor:
|
|||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self.online:
|
|
||||||
logger.warning("Transcription processor: self.online not initialized.")
|
|
||||||
self.transcription_queue.task_done()
|
|
||||||
continue
|
|
||||||
|
|
||||||
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
||||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||||
@@ -179,17 +228,16 @@ class AudioProcessor:
|
|||||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||||
if self.tokens:
|
if self.tokens:
|
||||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||||
logger.info(asr_processing_logs)
|
logger.info(asr_processing_logs)
|
||||||
|
|
||||||
if type(item) is Silence:
|
|
||||||
cumulative_pcm_duration_stream_time += item.duration
|
cumulative_pcm_duration_stream_time += item.duration
|
||||||
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(item, ChangeSpeaker):
|
||||||
if isinstance(item, np.ndarray):
|
self.online.new_speaker(item)
|
||||||
|
elif isinstance(item, np.ndarray):
|
||||||
pcm_array = item
|
pcm_array = item
|
||||||
else:
|
|
||||||
raise Exception('item should be pcm_array')
|
logger.info(asr_processing_logs)
|
||||||
|
|
||||||
duration_this_chunk = len(pcm_array) / self.sample_rate
|
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||||
@@ -198,32 +246,30 @@ class AudioProcessor:
|
|||||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
|
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
|
||||||
|
|
||||||
# Get buffer information
|
_buffer_transcript = self.online.get_buffer()
|
||||||
_buffer_transcript_obj = self.online.get_buffer()
|
buffer_text = _buffer_transcript.text
|
||||||
buffer_text = _buffer_transcript_obj.text
|
|
||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||||
if buffer_text.startswith(validated_text):
|
if buffer_text.startswith(validated_text):
|
||||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||||
|
|
||||||
candidate_end_times = [self.end_buffer]
|
candidate_end_times = [self.end_buffer]
|
||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
candidate_end_times.append(new_tokens[-1].end)
|
candidate_end_times.append(new_tokens[-1].end)
|
||||||
|
|
||||||
if _buffer_transcript_obj.end is not None:
|
if _buffer_transcript.end is not None:
|
||||||
candidate_end_times.append(_buffer_transcript_obj.end)
|
candidate_end_times.append(_buffer_transcript.end)
|
||||||
|
|
||||||
candidate_end_times.append(current_audio_processed_upto)
|
candidate_end_times.append(current_audio_processed_upto)
|
||||||
|
|
||||||
new_end_buffer = max(candidate_end_times)
|
async with self.lock:
|
||||||
|
self.tokens.extend(new_tokens)
|
||||||
|
self.buffer_transcription = _buffer_transcript
|
||||||
|
self.end_buffer = max(candidate_end_times)
|
||||||
|
|
||||||
await self.update_transcription(
|
if self.translation_queue:
|
||||||
new_tokens, buffer_text, new_end_buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_tokens and self.args.target_language and self.translation_queue:
|
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
await self.translation_queue.put(token)
|
await self.translation_queue.put(token)
|
||||||
|
|
||||||
@@ -247,8 +293,7 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def diarization_processor(self, diarization_obj):
|
async def diarization_processor(self, diarization_obj):
|
||||||
"""Process audio chunks for speaker diarization."""
|
"""Process audio chunks for speaker diarization."""
|
||||||
buffer_diarization = ""
|
self.current_speaker = 0
|
||||||
cumulative_pcm_duration_stream_time = 0.0
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
item = await self.diarization_queue.get()
|
item = await self.diarization_queue.get()
|
||||||
@@ -256,30 +301,36 @@ class AudioProcessor:
|
|||||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||||
self.diarization_queue.task_done()
|
self.diarization_queue.task_done()
|
||||||
break
|
break
|
||||||
|
elif type(item) is Silence:
|
||||||
if type(item) is Silence:
|
|
||||||
cumulative_pcm_duration_stream_time += item.duration
|
|
||||||
diarization_obj.insert_silence(item.duration)
|
diarization_obj.insert_silence(item.duration)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(item, np.ndarray):
|
||||||
if isinstance(item, np.ndarray):
|
|
||||||
pcm_array = item
|
pcm_array = item
|
||||||
else:
|
else:
|
||||||
raise Exception('item should be pcm_array')
|
raise Exception('item should be pcm_array')
|
||||||
|
|
||||||
# Process diarization
|
# Process diarization
|
||||||
await diarization_obj.diarize(pcm_array)
|
await diarization_obj.diarize(pcm_array)
|
||||||
|
segments = diarization_obj.get_segments()
|
||||||
|
|
||||||
async with self.lock:
|
if self.diarization_before_transcription:
|
||||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
if segments and segments[-1].speaker != self.current_speaker:
|
||||||
self.tokens,
|
self.current_speaker = segments[-1].speaker
|
||||||
use_punctuation_split=self.args.punctuation_split
|
cut_at = int(segments[-1].start*16000 - (self.cumulative_pcm_len))
|
||||||
)
|
await self.transcription_queue.put(pcm_array[cut_at:])
|
||||||
if len(self.tokens) > 0:
|
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=cut_at))
|
||||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
await self.transcription_queue.put(pcm_array[:cut_at])
|
||||||
if buffer_diarization:
|
else:
|
||||||
self.buffer_diarization = buffer_diarization
|
await self.transcription_queue.put(pcm_array)
|
||||||
|
else:
|
||||||
|
async with self.lock:
|
||||||
|
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||||
|
self.tokens,
|
||||||
|
use_punctuation_split=self.args.punctuation_split
|
||||||
|
)
|
||||||
|
self.cumulative_pcm_len += len(pcm_array)
|
||||||
|
if len(self.tokens) > 0:
|
||||||
|
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||||
self.diarization_queue.task_done()
|
self.diarization_queue.task_done()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -289,20 +340,23 @@ class AudioProcessor:
|
|||||||
self.diarization_queue.task_done()
|
self.diarization_queue.task_done()
|
||||||
logger.info("Diarization processor task finished.")
|
logger.info("Diarization processor task finished.")
|
||||||
|
|
||||||
async def translation_processor(self, online_translation):
|
async def translation_processor(self):
|
||||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||||
# And the speaker is attributed given the segments used for the translation
|
# And the speaker is attributed given the segments used for the translation
|
||||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
token = await self.translation_queue.get() #block until at least 1 token
|
item = await self.translation_queue.get() #block until at least 1 token
|
||||||
if token is SENTINEL:
|
if item is SENTINEL:
|
||||||
logger.debug("Translation processor received sentinel. Finishing.")
|
logger.debug("Translation processor received sentinel. Finishing.")
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
break
|
break
|
||||||
|
elif type(item) is Silence:
|
||||||
|
self.online_translation.insert_silence(item.duration)
|
||||||
|
continue
|
||||||
|
|
||||||
# get all the available tokens for translation. The more words, the more precise
|
# get all the available tokens for translation. The more words, the more precise
|
||||||
tokens_to_process = [token]
|
tokens_to_process = [item]
|
||||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
additional_tokens = await get_all_from_queue(self.translation_queue)
|
||||||
|
|
||||||
sentinel_found = False
|
sentinel_found = False
|
||||||
@@ -312,9 +366,8 @@ class AudioProcessor:
|
|||||||
break
|
break
|
||||||
tokens_to_process.append(additional_token)
|
tokens_to_process.append(additional_token)
|
||||||
if tokens_to_process:
|
if tokens_to_process:
|
||||||
online_translation.insert_tokens(tokens_to_process)
|
self.online_translation.insert_tokens(tokens_to_process)
|
||||||
self.translated_segments = await asyncio.to_thread(online_translation.process)
|
self.translated_segments = await asyncio.to_thread(self.online_translation.process)
|
||||||
|
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
for _ in additional_tokens:
|
for _ in additional_tokens:
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
@@ -326,7 +379,7 @@ class AudioProcessor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in translation_processor: {e}")
|
logger.warning(f"Exception in translation_processor: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
if 'token' in locals() and token is not SENTINEL:
|
if 'token' in locals() and item is not SENTINEL:
|
||||||
self.translation_queue.task_done()
|
self.translation_queue.task_done()
|
||||||
if 'additional_tokens' in locals():
|
if 'additional_tokens' in locals():
|
||||||
for _ in additional_tokens:
|
for _ in additional_tokens:
|
||||||
@@ -337,6 +390,16 @@ class AudioProcessor:
|
|||||||
"""Format processing results for output."""
|
"""Format processing results for output."""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
# If FFmpeg error occurred, notify front-end
|
||||||
|
if self._ffmpeg_error:
|
||||||
|
yield FrontData(
|
||||||
|
status="error",
|
||||||
|
error=f"FFmpeg error: {self._ffmpeg_error}"
|
||||||
|
)
|
||||||
|
self._ffmpeg_error = None
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
# Get current state
|
# Get current state
|
||||||
state = await self.get_current_state()
|
state = await self.get_current_state()
|
||||||
|
|
||||||
@@ -347,7 +410,7 @@ class AudioProcessor:
|
|||||||
state = await self.get_current_state()
|
state = await self.get_current_state()
|
||||||
|
|
||||||
# Format output
|
# Format output
|
||||||
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
|
lines, undiarized_text, end_w_silence = format_output(
|
||||||
state,
|
state,
|
||||||
self.silence,
|
self.silence,
|
||||||
current_time = time() - self.beg_loop if self.beg_loop else None,
|
current_time = time() - self.beg_loop if self.beg_loop else None,
|
||||||
@@ -355,30 +418,34 @@ class AudioProcessor:
|
|||||||
debug = self.debug,
|
debug = self.debug,
|
||||||
sep=self.sep
|
sep=self.sep
|
||||||
)
|
)
|
||||||
# Handle undiarized text
|
if end_w_silence:
|
||||||
|
buffer_transcription = Transcript()
|
||||||
|
else:
|
||||||
|
buffer_transcription = state.buffer_transcription
|
||||||
|
|
||||||
|
buffer_diarization = ''
|
||||||
if undiarized_text:
|
if undiarized_text:
|
||||||
combined = self.sep.join(undiarized_text)
|
buffer_diarization = self.sep.join(undiarized_text)
|
||||||
if buffer_transcription:
|
|
||||||
combined += self.sep
|
async with self.lock:
|
||||||
await self.update_diarization(state.end_attributed_speaker, combined)
|
self.end_attributed_speaker = state.end_attributed_speaker
|
||||||
buffer_diarization = combined
|
|
||||||
|
|
||||||
response_status = "active_transcription"
|
response_status = "active_transcription"
|
||||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||||
response_status = "no_audio_detected"
|
response_status = "no_audio_detected"
|
||||||
lines = []
|
lines = []
|
||||||
elif response_status == "active_transcription" and not lines:
|
elif not lines:
|
||||||
lines = [Line(
|
lines = [Line(
|
||||||
speaker=1,
|
speaker=1,
|
||||||
start=state.get("end_buffer", 0),
|
start=state.end_buffer,
|
||||||
end=state.get("end_buffer", 0)
|
end=state.end_buffer
|
||||||
)]
|
)]
|
||||||
|
|
||||||
response = FrontData(
|
response = FrontData(
|
||||||
status=response_status,
|
status=response_status,
|
||||||
lines=lines,
|
lines=lines,
|
||||||
buffer_transcription=buffer_transcription,
|
buffer_transcription=buffer_transcription.text.strip(),
|
||||||
buffer_diarization=buffer_diarization,
|
buffer_diarization=buffer_diarization.strip(),
|
||||||
remaining_time_transcription=state.remaining_time_transcription,
|
remaining_time_transcription=state.remaining_time_transcription,
|
||||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||||
)
|
)
|
||||||
@@ -412,6 +479,21 @@ class AudioProcessor:
|
|||||||
self.all_tasks_for_cleanup = []
|
self.all_tasks_for_cleanup = []
|
||||||
processing_tasks_for_watchdog = []
|
processing_tasks_for_watchdog = []
|
||||||
|
|
||||||
|
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
|
||||||
|
if not self.is_pcm_input:
|
||||||
|
success = await self.ffmpeg_manager.start()
|
||||||
|
if not success:
|
||||||
|
logger.error("Failed to start FFmpeg manager")
|
||||||
|
async def error_generator():
|
||||||
|
yield FrontData(
|
||||||
|
status="error",
|
||||||
|
error="FFmpeg failed to start. Please check that FFmpeg is installed."
|
||||||
|
)
|
||||||
|
return error_generator()
|
||||||
|
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||||
|
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||||
|
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
|
||||||
|
|
||||||
if self.args.transcription and self.online:
|
if self.args.transcription and self.online:
|
||||||
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
||||||
self.all_tasks_for_cleanup.append(self.transcription_task)
|
self.all_tasks_for_cleanup.append(self.transcription_task)
|
||||||
@@ -422,8 +504,8 @@ class AudioProcessor:
|
|||||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||||
|
|
||||||
if self.args.target_language and self.args.lan != 'auto':
|
if self.online_translation:
|
||||||
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
|
self.translation_task = asyncio.create_task(self.translation_processor())
|
||||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||||
processing_tasks_for_watchdog.append(self.translation_task)
|
processing_tasks_for_watchdog.append(self.translation_task)
|
||||||
|
|
||||||
@@ -462,13 +544,20 @@ class AudioProcessor:
|
|||||||
if task and not task.done():
|
if task and not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
||||||
if created_tasks:
|
if created_tasks:
|
||||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||||
logger.info("All processing tasks cancelled or finished.")
|
logger.info("All processing tasks cancelled or finished.")
|
||||||
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
|
|
||||||
self.diarization.close()
|
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||||
logger.info("AudioProcessor cleanup complete.")
|
try:
|
||||||
|
await self.ffmpeg_manager.stop()
|
||||||
|
logger.info("FFmpeg manager stopped.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error stopping FFmpeg manager: {e}")
|
||||||
|
if self.args.diarization and hasattr(self, 'dianization') and hasattr(self.diarization, 'close'):
|
||||||
|
self.diarization.close()
|
||||||
|
logger.info("AudioProcessor cleanup complete.")
|
||||||
|
|
||||||
|
|
||||||
async def process_audio(self, message):
|
async def process_audio(self, message):
|
||||||
@@ -484,6 +573,9 @@ class AudioProcessor:
|
|||||||
if self.transcription_queue:
|
if self.transcription_queue:
|
||||||
await self.transcription_queue.put(SENTINEL)
|
await self.transcription_queue.put(SENTINEL)
|
||||||
|
|
||||||
|
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||||
|
await self.ffmpeg_manager.stop()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.is_stopping:
|
if self.is_stopping:
|
||||||
@@ -493,6 +585,17 @@ class AudioProcessor:
|
|||||||
if self.is_pcm_input:
|
if self.is_pcm_input:
|
||||||
self.pcm_buffer.extend(message)
|
self.pcm_buffer.extend(message)
|
||||||
await self.handle_pcm_data()
|
await self.handle_pcm_data()
|
||||||
|
else:
|
||||||
|
if not self.ffmpeg_manager:
|
||||||
|
logger.error("FFmpeg manager not initialized for non-PCM input.")
|
||||||
|
return
|
||||||
|
success = await self.ffmpeg_manager.write_data(message)
|
||||||
|
if not success:
|
||||||
|
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||||
|
if ffmpeg_state == FFmpegState.FAILED:
|
||||||
|
logger.error("FFmpeg is in FAILED state, cannot process audio")
|
||||||
|
else:
|
||||||
|
logger.warning("Failed to write audio data to FFmpeg")
|
||||||
|
|
||||||
async def handle_pcm_data(self):
|
async def handle_pcm_data(self):
|
||||||
# Process when enough data
|
# Process when enough data
|
||||||
@@ -524,13 +627,15 @@ class AudioProcessor:
|
|||||||
silence_buffer = Silence(duration=time() - self.start_silence)
|
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||||
|
|
||||||
if silence_buffer:
|
if silence_buffer:
|
||||||
if self.args.transcription and self.transcription_queue:
|
if not self.diarization_before_transcription and self.transcription_queue:
|
||||||
await self.transcription_queue.put(silence_buffer)
|
await self.transcription_queue.put(silence_buffer)
|
||||||
if self.args.diarization and self.diarization_queue:
|
if self.args.diarization and self.diarization_queue:
|
||||||
await self.diarization_queue.put(silence_buffer)
|
await self.diarization_queue.put(silence_buffer)
|
||||||
|
if self.translation_queue:
|
||||||
|
await self.translation_queue.put(silence_buffer)
|
||||||
|
|
||||||
if not self.silence:
|
if not self.silence:
|
||||||
if self.args.transcription and self.transcription_queue:
|
if not self.diarization_before_transcription and self.transcription_queue:
|
||||||
await self.transcription_queue.put(pcm_array.copy())
|
await self.transcription_queue.put(pcm_array.copy())
|
||||||
|
|
||||||
if self.args.diarization and self.diarization_queue:
|
if self.args.diarization and self.diarization_queue:
|
||||||
|
|||||||
@@ -5,9 +5,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from starlette.staticfiles import StaticFiles
|
|
||||||
import pathlib
|
|
||||||
import whisperlivekit.web as webpkg
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logging.getLogger().setLevel(logging.WARNING)
|
logging.getLogger().setLevel(logging.WARNING)
|
||||||
@@ -19,15 +16,6 @@ transcription_engine = None
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
#to remove after 0.2.8
|
|
||||||
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
|
|
||||||
logger.warning(f"""
|
|
||||||
{'='*50}
|
|
||||||
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
|
|
||||||
{'='*50}
|
|
||||||
""")
|
|
||||||
|
|
||||||
global transcription_engine
|
global transcription_engine
|
||||||
transcription_engine = TranscriptionEngine(
|
transcription_engine = TranscriptionEngine(
|
||||||
**vars(args),
|
**vars(args),
|
||||||
@@ -42,8 +30,6 @@ app.add_middleware(
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
web_dir = pathlib.Path(webpkg.__file__).parent
|
|
||||||
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def get():
|
async def get():
|
||||||
@@ -73,6 +59,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
logger.info("WebSocket connection opened.")
|
logger.info("WebSocket connection opened.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to send config to client: {e}")
|
||||||
|
|
||||||
results_generator = await audio_processor.create_tasks()
|
results_generator = await audio_processor.create_tasks()
|
||||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||||
|
|
||||||
|
|||||||
@@ -43,10 +43,12 @@ class TranscriptionEngine:
|
|||||||
"transcription": True,
|
"transcription": True,
|
||||||
"vad": True,
|
"vad": True,
|
||||||
"pcm_input": False,
|
"pcm_input": False,
|
||||||
|
|
||||||
# whisperstreaming params:
|
# whisperstreaming params:
|
||||||
"buffer_trimming": "segment",
|
"buffer_trimming": "segment",
|
||||||
"confidence_validation": False,
|
"confidence_validation": False,
|
||||||
"buffer_trimming_sec": 15,
|
"buffer_trimming_sec": 15,
|
||||||
|
|
||||||
# simulstreaming params:
|
# simulstreaming params:
|
||||||
"disable_fast_encoder": False,
|
"disable_fast_encoder": False,
|
||||||
"frame_threshold": 25,
|
"frame_threshold": 25,
|
||||||
@@ -61,10 +63,15 @@ class TranscriptionEngine:
|
|||||||
"max_context_tokens": None,
|
"max_context_tokens": None,
|
||||||
"model_path": './base.pt',
|
"model_path": './base.pt',
|
||||||
"diarization_backend": "sortformer",
|
"diarization_backend": "sortformer",
|
||||||
|
|
||||||
# diarization params:
|
# diarization params:
|
||||||
"disable_punctuation_split" : False,
|
"disable_punctuation_split" : False,
|
||||||
"segmentation_model": "pyannote/segmentation-3.0",
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
"embedding_model": "pyannote/embedding",
|
"embedding_model": "pyannote/embedding",
|
||||||
|
|
||||||
|
# translation params:
|
||||||
|
"nllb_backend": "ctranslate2",
|
||||||
|
"nllb_size": "600M"
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict = {**defaults, **kwargs}
|
config_dict = {**defaults, **kwargs}
|
||||||
@@ -138,12 +145,11 @@ class TranscriptionEngine:
|
|||||||
|
|
||||||
self.translation_model = None
|
self.translation_model = None
|
||||||
if self.args.target_language:
|
if self.args.target_language:
|
||||||
if self.args.lan == 'auto':
|
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
||||||
raise Exception('Translation cannot be set with language auto')
|
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||||
else:
|
else:
|
||||||
from whisperlivekit.translation.translation import load_model
|
from whisperlivekit.translation.translation import load_model
|
||||||
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
|
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers
|
||||||
|
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ class SortformerDiarizationOnline:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of tokens with speaker assignments
|
List of tokens with speaker assignments
|
||||||
|
Last speaker_segment
|
||||||
"""
|
"""
|
||||||
with self.segment_lock:
|
with self.segment_lock:
|
||||||
segments = self.speaker_segments.copy()
|
segments = self.speaker_segments.copy()
|
||||||
|
|||||||
197
whisperlivekit/ffmpeg_manager.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Callable
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
ERROR_INSTALL_INSTRUCTIONS = f"""
|
||||||
|
{'='*50}
|
||||||
|
FFmpeg is not installed or not found in your system's PATH.
|
||||||
|
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
|
||||||
|
|
||||||
|
If you want to install FFmpeg:
|
||||||
|
|
||||||
|
# Ubuntu/Debian:
|
||||||
|
sudo apt update && sudo apt install ffmpeg
|
||||||
|
|
||||||
|
# macOS (using Homebrew):
|
||||||
|
brew install ffmpeg
|
||||||
|
|
||||||
|
# Windows:
|
||||||
|
# 1. Download the latest static build from https://ffmpeg.org/download.html
|
||||||
|
# 2. Extract the archive (e.g., to C:\\FFmpeg).
|
||||||
|
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
|
||||||
|
|
||||||
|
After installation, please restart the application.
|
||||||
|
{'='*50}
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FFmpegState(Enum):
|
||||||
|
STOPPED = "stopped"
|
||||||
|
STARTING = "starting"
|
||||||
|
RUNNING = "running"
|
||||||
|
RESTARTING = "restarting"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
class FFmpegManager:
|
||||||
|
def __init__(self, sample_rate: int = 16000, channels: int = 1):
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
self.process: Optional[asyncio.subprocess.Process] = None
|
||||||
|
self._stderr_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
self.on_error_callback: Optional[Callable[[str], None]] = None
|
||||||
|
|
||||||
|
self.state = FFmpegState.STOPPED
|
||||||
|
self._state_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def start(self) -> bool:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state != FFmpegState.STOPPED:
|
||||||
|
logger.warning(f"FFmpeg already running in state: {self.state}")
|
||||||
|
return False
|
||||||
|
self.state = FFmpegState.STARTING
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-hide_banner",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"-i", "pipe:0",
|
||||||
|
"-f", "s16le",
|
||||||
|
"-acodec", "pcm_s16le",
|
||||||
|
"-ac", str(self.channels),
|
||||||
|
"-ar", str(self.sample_rate),
|
||||||
|
"pipe:1"
|
||||||
|
]
|
||||||
|
|
||||||
|
self.process = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE
|
||||||
|
)
|
||||||
|
|
||||||
|
self._stderr_task = asyncio.create_task(self._drain_stderr())
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.RUNNING
|
||||||
|
|
||||||
|
logger.info("FFmpeg started.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(ERROR_INSTALL_INSTRUCTIONS)
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.FAILED
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("ffmpeg_not_found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error starting FFmpeg: {e}")
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.FAILED
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("start_failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state == FFmpegState.STOPPED:
|
||||||
|
return
|
||||||
|
self.state = FFmpegState.STOPPED
|
||||||
|
|
||||||
|
if self.process:
|
||||||
|
if self.process.stdin and not self.process.stdin.is_closing():
|
||||||
|
self.process.stdin.close()
|
||||||
|
await self.process.stdin.wait_closed()
|
||||||
|
await self.process.wait()
|
||||||
|
self.process = None
|
||||||
|
|
||||||
|
if self._stderr_task:
|
||||||
|
self._stderr_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._stderr_task
|
||||||
|
|
||||||
|
logger.info("FFmpeg stopped.")
|
||||||
|
|
||||||
|
async def write_data(self, data: bytes) -> bool:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state != FFmpegState.RUNNING:
|
||||||
|
logger.warning(f"Cannot write, FFmpeg state: {self.state}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.process.stdin.write(data)
|
||||||
|
await self.process.stdin.drain()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing to FFmpeg: {e}")
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("write_error")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def read_data(self, size: int) -> Optional[bytes]:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state != FFmpegState.RUNNING:
|
||||||
|
logger.warning(f"Cannot read, FFmpeg state: {self.state}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(
|
||||||
|
self.process.stdout.read(size),
|
||||||
|
timeout=20.0
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("FFmpeg read timeout.")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading from FFmpeg: {e}")
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("read_error")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_state(self) -> FFmpegState:
|
||||||
|
async with self._state_lock:
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
async def restart(self) -> bool:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state == FFmpegState.RESTARTING:
|
||||||
|
logger.warning("Restart already in progress.")
|
||||||
|
return False
|
||||||
|
self.state = FFmpegState.RESTARTING
|
||||||
|
|
||||||
|
logger.info("Restarting FFmpeg...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.stop()
|
||||||
|
await asyncio.sleep(1) # short delay before restarting
|
||||||
|
return await self.start()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during FFmpeg restart: {e}")
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.FAILED
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("restart_failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _drain_stderr(self):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if not self.process or not self.process.stderr:
|
||||||
|
break
|
||||||
|
line = await self.process.stderr.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
logger.debug(f"FFmpeg stderr: {line.decode(errors='ignore').strip()}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("FFmpeg stderr drain task cancelled.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||||
@@ -177,7 +177,7 @@ def parse_args():
|
|||||||
"--pcm-input",
|
"--pcm-input",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed."
|
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
||||||
)
|
)
|
||||||
# SimulStreaming-specific arguments
|
# SimulStreaming-specific arguments
|
||||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||||
@@ -287,6 +287,20 @@ def parse_args():
|
|||||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--nllb-backend",
|
||||||
|
type=str,
|
||||||
|
default="ctranslate2",
|
||||||
|
help="transformers or ctranslate2",
|
||||||
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--nllb-size",
|
||||||
|
type=str,
|
||||||
|
default="600M",
|
||||||
|
help="600M or 1.3B",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.transcription = not args.no_transcription
|
args.transcription = not args.no_transcription
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def blank_to_silence(tokens):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if silence_token: #there was silence but no more
|
if silence_token: #there was silence but no more
|
||||||
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||||
cleaned_tokens.append(
|
cleaned_tokens.append(
|
||||||
silence_token
|
silence_token
|
||||||
)
|
)
|
||||||
@@ -77,15 +77,17 @@ def no_token_to_silence(tokens):
|
|||||||
new_tokens.append(token)
|
new_tokens.append(token)
|
||||||
return new_tokens
|
return new_tokens
|
||||||
|
|
||||||
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
def ends_with_silence(tokens, current_time, vac_detected_silence):
|
||||||
|
end_w_silence = False
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return [], buffer_transcription, buffer_diarization
|
return [], end_w_silence
|
||||||
last_token = tokens[-1]
|
last_token = tokens[-1]
|
||||||
if tokens and current_time and (
|
if tokens and current_time and (
|
||||||
current_time - last_token.end >= END_SILENCE_DURATION
|
current_time - last_token.end >= END_SILENCE_DURATION
|
||||||
or
|
or
|
||||||
(current_time - last_token.end >= 3 and vac_detected_silence)
|
(current_time - last_token.end >= 3 and vac_detected_silence)
|
||||||
):
|
):
|
||||||
|
end_w_silence = True
|
||||||
if last_token.speaker == -2:
|
if last_token.speaker == -2:
|
||||||
last_token.end = current_time
|
last_token.end = current_time
|
||||||
else:
|
else:
|
||||||
@@ -97,14 +99,12 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
|
|||||||
probability=0.95
|
probability=0.95
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
|
return tokens, end_w_silence
|
||||||
buffer_diarization = ""
|
|
||||||
return tokens, buffer_transcription, buffer_diarization
|
|
||||||
|
|
||||||
|
|
||||||
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
def handle_silences(tokens, current_time, vac_detected_silence):
|
||||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||||
tokens = no_token_to_silence(tokens)
|
tokens = no_token_to_silence(tokens)
|
||||||
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
|
tokens, end_w_silence = ends_with_silence(tokens, current_time, vac_detected_silence)
|
||||||
return tokens, buffer_transcription, buffer_diarization
|
return tokens, end_w_silence
|
||||||
|
|
||||||
@@ -6,11 +6,10 @@ from whisperlivekit.timed_objects import Line, format_time
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
|
||||||
CHECK_AROUND = 4
|
CHECK_AROUND = 4
|
||||||
|
|
||||||
def is_punctuation(token):
|
def is_punctuation(token):
|
||||||
if token.text.strip() in PUNCTUATION_MARKS:
|
if token.is_punctuation():
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -39,26 +38,28 @@ def new_line(
|
|||||||
text = token.text + debug_info,
|
text = token.text + debug_info,
|
||||||
start = token.start,
|
start = token.start,
|
||||||
end = token.end,
|
end = token.end,
|
||||||
|
detected_language=token.detected_language
|
||||||
)
|
)
|
||||||
|
|
||||||
def append_token_to_last_line(lines, sep, token, debug_info):
|
def append_token_to_last_line(lines, sep, token, debug_info):
|
||||||
if token.text:
|
if token.text:
|
||||||
lines[-1].text += sep + token.text + debug_info
|
lines[-1].text += sep + token.text + debug_info
|
||||||
lines[-1].end = token.end
|
lines[-1].end = token.end
|
||||||
|
if not lines[-1].detected_language and token.detected_language:
|
||||||
|
lines[-1].detected_language = token.detected_language
|
||||||
|
|
||||||
|
|
||||||
def format_output(state, silence, current_time, args, debug, sep):
|
def format_output(state, silence, current_time, args, debug, sep):
|
||||||
diarization = args.diarization
|
diarization = args.diarization
|
||||||
disable_punctuation_split = args.disable_punctuation_split
|
disable_punctuation_split = args.disable_punctuation_split
|
||||||
tokens = state.tokens
|
tokens = state.tokens
|
||||||
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||||
buffer_transcription = state.buffer_transcription
|
|
||||||
buffer_diarization = state.buffer_diarization
|
|
||||||
end_attributed_speaker = state.end_attributed_speaker
|
end_attributed_speaker = state.end_attributed_speaker
|
||||||
|
|
||||||
previous_speaker = -1
|
previous_speaker = -1
|
||||||
lines = []
|
lines = []
|
||||||
undiarized_text = []
|
undiarized_text = []
|
||||||
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
|
tokens, end_w_silence = handle_silences(tokens, current_time, silence)
|
||||||
last_punctuation = None
|
last_punctuation = None
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
speaker = token.speaker
|
speaker = token.speaker
|
||||||
@@ -122,15 +123,39 @@ def format_output(state, silence, current_time, args, debug, sep):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
append_token_to_last_line(lines, sep, token, debug_info)
|
append_token_to_last_line(lines, sep, token, debug_info)
|
||||||
if lines and translated_segments:
|
|
||||||
cts_idx = 0 # current_translated_segment_idx
|
|
||||||
for line in lines:
|
|
||||||
while cts_idx < len(translated_segments):
|
|
||||||
ts = translated_segments[cts_idx]
|
|
||||||
if ts and ts.start and ts.start >= line.start and ts.end <= line.end:
|
|
||||||
line.translation += ts.text + ' '
|
|
||||||
cts_idx += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return lines, undiarized_text, buffer_transcription, ''
|
|
||||||
|
|
||||||
|
if lines and translated_segments:
|
||||||
|
unassigned_translated_segments = []
|
||||||
|
for ts in translated_segments:
|
||||||
|
assigned = False
|
||||||
|
for line in lines:
|
||||||
|
if ts and ts.overlaps_with(line):
|
||||||
|
if ts.is_within(line):
|
||||||
|
line.translation += ts.text + ' '
|
||||||
|
assigned = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||||
|
if ts0 and line.overlaps_with(ts0):
|
||||||
|
line.translation += ts0.text + ' '
|
||||||
|
if ts1:
|
||||||
|
unassigned_translated_segments.append(ts1)
|
||||||
|
assigned = True
|
||||||
|
break
|
||||||
|
if not assigned:
|
||||||
|
unassigned_translated_segments.append(ts)
|
||||||
|
|
||||||
|
if unassigned_translated_segments:
|
||||||
|
for line in lines:
|
||||||
|
remaining_segments = []
|
||||||
|
for ts in unassigned_translated_segments:
|
||||||
|
if ts and ts.overlaps_with(line):
|
||||||
|
line.translation += ts.text + ' '
|
||||||
|
else:
|
||||||
|
remaining_segments.append(ts)
|
||||||
|
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
||||||
|
|
||||||
|
if state.buffer_transcription and lines:
|
||||||
|
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||||
|
|
||||||
|
return lines, undiarized_text, end_w_silence
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ import logging
|
|||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||||
from whisperlivekit.warmup import load_file
|
from whisperlivekit.warmup import load_file
|
||||||
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
|
||||||
from .whisper import load_model, tokenizer
|
from .whisper import load_model, tokenizer
|
||||||
from .whisper.audio import TOKENS_PER_SECOND
|
from .whisper.audio import TOKENS_PER_SECOND
|
||||||
import os
|
import os
|
||||||
@@ -23,7 +22,11 @@ try:
|
|||||||
HAS_MLX_WHISPER = True
|
HAS_MLX_WHISPER = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||||
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
|
print(f"""
|
||||||
|
{"="*50}
|
||||||
|
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
|
||||||
|
{"="*50}
|
||||||
|
""")
|
||||||
HAS_MLX_WHISPER = False
|
HAS_MLX_WHISPER = False
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
HAS_FASTER_WHISPER = False
|
HAS_FASTER_WHISPER = False
|
||||||
@@ -49,8 +52,7 @@ class SimulStreamingOnlineProcessor:
|
|||||||
self.asr = asr
|
self.asr = asr
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.end = 0.0
|
self.end = 0.0
|
||||||
self.global_time_offset = 0.0
|
self.buffer = []
|
||||||
|
|
||||||
self.committed: List[ASRToken] = []
|
self.committed: List[ASRToken] = []
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
self.last_result_tokens: List[ASRToken] = []
|
||||||
self.load_new_backend()
|
self.load_new_backend()
|
||||||
@@ -79,7 +81,7 @@ class SimulStreamingOnlineProcessor:
|
|||||||
else:
|
else:
|
||||||
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
|
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
|
||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
self.global_time_offset = silence_duration + offset
|
self.model.global_time_offset = silence_duration + offset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -91,63 +93,15 @@ class SimulStreamingOnlineProcessor:
|
|||||||
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||||
self.model.insert_audio(audio_tensor)
|
self.model.insert_audio(audio_tensor)
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||||
|
self.process_iter(is_last=True)
|
||||||
|
self.model.refresh_segment(complete=True)
|
||||||
|
self.model.speaker = change_speaker.speaker
|
||||||
|
self.global_time_offset = change_speaker.start
|
||||||
|
|
||||||
def get_buffer(self):
|
def get_buffer(self):
|
||||||
return Transcript(
|
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||||
start=None,
|
return concat_buffer
|
||||||
end=None,
|
|
||||||
text='',
|
|
||||||
probability=None
|
|
||||||
)
|
|
||||||
|
|
||||||
def timestamped_text(self, tokens, generation):
|
|
||||||
"""
|
|
||||||
generate timestamped text from tokens and generation data.
|
|
||||||
|
|
||||||
args:
|
|
||||||
tokens: List of tokens to process
|
|
||||||
generation: Dictionary containing generation progress and optionally results
|
|
||||||
|
|
||||||
returns:
|
|
||||||
List of tuples containing (start_time, end_time, word) for each word
|
|
||||||
"""
|
|
||||||
FRAME_DURATION = 0.02
|
|
||||||
if "result" in generation:
|
|
||||||
split_words = generation["result"]["split_words"]
|
|
||||||
split_tokens = generation["result"]["split_tokens"]
|
|
||||||
else:
|
|
||||||
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
|
|
||||||
progress = generation["progress"]
|
|
||||||
frames = [p["most_attended_frames"][0] for p in progress]
|
|
||||||
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
|
|
||||||
tokens_queue = tokens.copy()
|
|
||||||
timestamped_words = []
|
|
||||||
|
|
||||||
for word, word_tokens in zip(split_words, split_tokens):
|
|
||||||
# start_frame = None
|
|
||||||
# end_frame = None
|
|
||||||
for expected_token in word_tokens:
|
|
||||||
if not tokens_queue or not frames:
|
|
||||||
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
|
|
||||||
|
|
||||||
actual_token = tokens_queue.pop(0)
|
|
||||||
current_frame = frames.pop(0)
|
|
||||||
current_timestamp = absolute_timestamps.pop(0)
|
|
||||||
if actual_token != expected_token:
|
|
||||||
raise ValueError(
|
|
||||||
f"Token mismatch: expected '{expected_token}', "
|
|
||||||
f"got '{actual_token}' at frame {current_frame}"
|
|
||||||
)
|
|
||||||
# if start_frame is None:
|
|
||||||
# start_frame = current_frame
|
|
||||||
# end_frame = current_frame
|
|
||||||
# start_time = start_frame * FRAME_DURATION
|
|
||||||
# end_time = end_frame * FRAME_DURATION
|
|
||||||
start_time = current_timestamp
|
|
||||||
end_time = current_timestamp + 0.1
|
|
||||||
timestamp_entry = (start_time, end_time, word)
|
|
||||||
timestamped_words.append(timestamp_entry)
|
|
||||||
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
|
|
||||||
return timestamped_words
|
|
||||||
|
|
||||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
"""
|
"""
|
||||||
@@ -156,47 +110,14 @@ class SimulStreamingOnlineProcessor:
|
|||||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
tokens, generation_progress = self.model.infer(is_last=is_last)
|
timestamped_words = self.model.infer(is_last=is_last)
|
||||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
if timestamped_words and timestamped_words[0].detected_language == None:
|
||||||
|
self.buffer.extend(timestamped_words)
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
new_tokens = []
|
self.committed.extend(timestamped_words)
|
||||||
for ts_word in ts_words:
|
self.buffer = []
|
||||||
|
return timestamped_words, self.end
|
||||||
start, end, word = ts_word
|
|
||||||
token = ASRToken(
|
|
||||||
start=start,
|
|
||||||
end=end,
|
|
||||||
text=word,
|
|
||||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
|
||||||
).with_offset(
|
|
||||||
self.global_time_offset
|
|
||||||
)
|
|
||||||
new_tokens.append(token)
|
|
||||||
|
|
||||||
# identical_tokens = 0
|
|
||||||
# n_new_tokens = len(new_tokens)
|
|
||||||
# if n_new_tokens:
|
|
||||||
|
|
||||||
self.committed.extend(new_tokens)
|
|
||||||
|
|
||||||
# if token in self.committed:
|
|
||||||
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
|
|
||||||
# if pos:
|
|
||||||
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
|
|
||||||
# commited_segment = self.committed[i:i+n_new_tokens]
|
|
||||||
# if commited_segment == new_tokens:
|
|
||||||
# identical_segments +=1
|
|
||||||
# if identical_tokens >= TOO_MANY_REPETITIONS:
|
|
||||||
# logger.warning('Too many repetition, model is stuck. Load a new one')
|
|
||||||
# self.committed = self.committed[:i]
|
|
||||||
# self.load_new_backend()
|
|
||||||
# return [], self.end
|
|
||||||
|
|
||||||
# pos = self.committed.rindex(token)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return new_tokens, self.end
|
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -226,7 +147,6 @@ class SimulStreamingASR():
|
|||||||
sep = ""
|
sep = ""
|
||||||
|
|
||||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||||
logger.warning(SIMULSTREAMING_LICENSE)
|
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.transcribe_kargs = {}
|
self.transcribe_kargs = {}
|
||||||
self.original_language = lan
|
self.original_language = lan
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from .whisper import load_model, DecodingOptions, tokenizer
|
from .whisper import load_model, DecodingOptions, tokenizer
|
||||||
from .config import AlignAttConfig
|
from .config import AlignAttConfig
|
||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||||
from .whisper.timing import median_filter
|
from .whisper.timing import median_filter
|
||||||
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||||
@@ -18,6 +19,7 @@ from time import time
|
|||||||
from .token_buffer import TokenBuffer
|
from .token_buffer import TokenBuffer
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from ..timed_objects import PUNCTUATION_MARKS
|
||||||
from .generation_progress import *
|
from .generation_progress import *
|
||||||
|
|
||||||
DEC_PAD = 50257
|
DEC_PAD = 50257
|
||||||
@@ -40,12 +42,6 @@ else:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_FASTER_WHISPER = False
|
HAS_FASTER_WHISPER = False
|
||||||
|
|
||||||
# New features added to the original version of Simul-Whisper:
|
|
||||||
# - large-v3 model support
|
|
||||||
# - translation support
|
|
||||||
# - beam search
|
|
||||||
# - prompt -- static vs. non-static
|
|
||||||
# - context
|
|
||||||
class PaddedAlignAttWhisper:
|
class PaddedAlignAttWhisper:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -70,7 +66,7 @@ class PaddedAlignAttWhisper:
|
|||||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||||
|
|
||||||
logger.info(f"Model dimensions: {self.model.dims}")
|
logger.info(f"Model dimensions: {self.model.dims}")
|
||||||
|
self.speaker = -1
|
||||||
self.decode_options = DecodingOptions(
|
self.decode_options = DecodingOptions(
|
||||||
language = cfg.language,
|
language = cfg.language,
|
||||||
without_timestamps = True,
|
without_timestamps = True,
|
||||||
@@ -78,7 +74,10 @@ class PaddedAlignAttWhisper:
|
|||||||
)
|
)
|
||||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||||
|
# self.create_tokenizer('en')
|
||||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||||
|
self.global_time_offset = 0.0
|
||||||
|
self.reset_tokenizer_to_auto_next_call = False
|
||||||
|
|
||||||
self.max_text_len = self.model.dims.n_text_ctx
|
self.max_text_len = self.model.dims.n_text_ctx
|
||||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||||
@@ -153,6 +152,7 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
self.cumulative_time_offset = 0.0
|
self.cumulative_time_offset = 0.0
|
||||||
|
self.first_timestamp = None
|
||||||
|
|
||||||
if self.cfg.max_context_tokens is None:
|
if self.cfg.max_context_tokens is None:
|
||||||
self.max_context_tokens = self.max_text_len
|
self.max_context_tokens = self.max_text_len
|
||||||
@@ -260,7 +260,6 @@ class PaddedAlignAttWhisper:
|
|||||||
self.init_context()
|
self.init_context()
|
||||||
logger.debug(f"Context: {self.context}")
|
logger.debug(f"Context: {self.context}")
|
||||||
if not complete and len(self.segments) > 2:
|
if not complete and len(self.segments) > 2:
|
||||||
logger.debug("keeping last two segments because they are and it is not complete.")
|
|
||||||
self.segments = self.segments[-2:]
|
self.segments = self.segments[-2:]
|
||||||
else:
|
else:
|
||||||
logger.debug("removing all segments.")
|
logger.debug("removing all segments.")
|
||||||
@@ -382,11 +381,11 @@ class PaddedAlignAttWhisper:
|
|||||||
new_segment = True
|
new_segment = True
|
||||||
if len(self.segments) == 0:
|
if len(self.segments) == 0:
|
||||||
logger.debug("No segments, nothing to do")
|
logger.debug("No segments, nothing to do")
|
||||||
return [], {}
|
return []
|
||||||
if not self._apply_minseglen():
|
if not self._apply_minseglen():
|
||||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||||
input_segments = torch.cat(self.segments, dim=0)
|
input_segments = torch.cat(self.segments, dim=0)
|
||||||
return [], {}
|
return []
|
||||||
|
|
||||||
# input_segments is concatenation of audio, it's one array
|
# input_segments is concatenation of audio, it's one array
|
||||||
if len(self.segments) > 1:
|
if len(self.segments) > 1:
|
||||||
@@ -394,6 +393,13 @@ class PaddedAlignAttWhisper:
|
|||||||
else:
|
else:
|
||||||
input_segments = self.segments[0]
|
input_segments = self.segments[0]
|
||||||
|
|
||||||
|
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
|
||||||
|
# logger.debug("Resetting tokenizer to auto for new sentence.")
|
||||||
|
# self.create_tokenizer(None)
|
||||||
|
# self.detected_language = None
|
||||||
|
# self.init_tokens()
|
||||||
|
# self.reset_tokenizer_to_auto_next_call = False
|
||||||
|
|
||||||
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
|
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
|
||||||
beg_encode = time()
|
beg_encode = time()
|
||||||
if self.mlx_encoder:
|
if self.mlx_encoder:
|
||||||
@@ -426,58 +432,38 @@ class PaddedAlignAttWhisper:
|
|||||||
end_encode = time()
|
end_encode = time()
|
||||||
# print('Encoder duration:', end_encode-beg_encode)
|
# print('Encoder duration:', end_encode-beg_encode)
|
||||||
|
|
||||||
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
|
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
|
||||||
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
seconds_since_start = self.segments_len() - self.first_timestamp
|
||||||
# logger.debug("mel ")
|
if seconds_since_start >= 2.0:
|
||||||
if self.cfg.language == "auto" and self.detected_language is None:
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
|
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
self.create_tokenizer(top_lan)
|
||||||
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
#self.tokenizer.language = top_lan
|
self.cumulative_time_offset = 0.0
|
||||||
#self.tokenizer.__post_init__()
|
self.init_tokens()
|
||||||
self.create_tokenizer(top_lan)
|
self.init_context()
|
||||||
self.detected_language = top_lan
|
self.detected_language = top_lan
|
||||||
self.init_tokens()
|
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
|
||||||
|
|
||||||
self.trim_context()
|
self.trim_context()
|
||||||
current_tokens = self._current_tokens()
|
current_tokens = self._current_tokens()
|
||||||
#
|
|
||||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||||
|
|
||||||
|
|
||||||
####################### Decoding loop
|
|
||||||
logger.info("Decoding loop starts\n")
|
|
||||||
|
|
||||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
||||||
completed = False
|
completed = False
|
||||||
|
# punctuation_stop = False
|
||||||
|
|
||||||
attn_of_alignment_heads = None
|
attn_of_alignment_heads = None
|
||||||
most_attended_frame = None
|
most_attended_frame = None
|
||||||
|
|
||||||
token_len_before_decoding = current_tokens.shape[1]
|
token_len_before_decoding = current_tokens.shape[1]
|
||||||
|
|
||||||
generation_progress = []
|
l_absolute_timestamps = []
|
||||||
generation = {
|
|
||||||
"starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
|
|
||||||
"token_len_before_decoding": token_len_before_decoding,
|
|
||||||
#"fire_detected": fire_detected,
|
|
||||||
"frames_len": content_mel_len,
|
|
||||||
"frames_threshold": 4 if is_last else self.cfg.frame_threshold,
|
|
||||||
|
|
||||||
# to be filled later
|
|
||||||
"logits_starting": None,
|
|
||||||
|
|
||||||
# to be filled later
|
|
||||||
"no_speech_prob": None,
|
|
||||||
"no_speech": False,
|
|
||||||
|
|
||||||
# to be filled in the loop
|
|
||||||
"progress": generation_progress,
|
|
||||||
}
|
|
||||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||||
generation_progress_loop = []
|
|
||||||
|
|
||||||
if new_segment:
|
if new_segment:
|
||||||
tokens_for_logits = current_tokens
|
tokens_for_logits = current_tokens
|
||||||
@@ -486,50 +472,26 @@ class PaddedAlignAttWhisper:
|
|||||||
tokens_for_logits = current_tokens[:,-1:]
|
tokens_for_logits = current_tokens[:,-1:]
|
||||||
|
|
||||||
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
||||||
if new_segment:
|
|
||||||
generation["logits_starting"] = Logits(logits[:,:,:])
|
|
||||||
|
|
||||||
if new_segment and self.tokenizer.no_speech is not None:
|
if new_segment and self.tokenizer.no_speech is not None:
|
||||||
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
||||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||||
generation["no_speech_prob"] = no_speech_probs[0]
|
|
||||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||||
generation["no_speech"] = True
|
|
||||||
logger.info("no speech, stop")
|
logger.info("no speech, stop")
|
||||||
break
|
break
|
||||||
|
|
||||||
logits = logits[:, -1, :] # logits for the last token
|
logits = logits[:, -1, :] # logits for the last token
|
||||||
generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
|
|
||||||
|
|
||||||
# supress blank tokens only at the beginning of the segment
|
# supress blank tokens only at the beginning of the segment
|
||||||
if new_segment:
|
if new_segment:
|
||||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||||
new_segment = False
|
new_segment = False
|
||||||
self.suppress_tokens(logits)
|
self.suppress_tokens(logits)
|
||||||
#generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
|
|
||||||
generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
|
|
||||||
|
|
||||||
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||||
generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
|
|
||||||
generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
|
|
||||||
generation_progress_loop.append(("completed",completed))
|
|
||||||
|
|
||||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||||
self.debug_print_tokens(current_tokens)
|
self.debug_print_tokens(current_tokens)
|
||||||
|
|
||||||
|
|
||||||
# if self.decoder_type == "beam":
|
|
||||||
# logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
|
|
||||||
|
|
||||||
# logprobs = F.log_softmax(logits.float(), dim=-1)
|
|
||||||
# idx = 0
|
|
||||||
# logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
|
|
||||||
# logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
|
|
||||||
# if completed:
|
|
||||||
# self.debug_print_tokens(current_tokens)
|
|
||||||
|
|
||||||
# logger.debug("decode stopped because decoder completed")
|
|
||||||
|
|
||||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
||||||
for i, attn_mat in enumerate(self.dec_attns):
|
for i, attn_mat in enumerate(self.dec_attns):
|
||||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
layer_rank = int(i % len(self.model.decoder.blocks))
|
||||||
@@ -548,30 +510,24 @@ class PaddedAlignAttWhisper:
|
|||||||
t = torch.cat(mat, dim=1)
|
t = torch.cat(mat, dim=1)
|
||||||
tmp.append(t)
|
tmp.append(t)
|
||||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||||
# logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
|
|
||||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
||||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
||||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||||
# logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
|
|
||||||
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
||||||
# logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
|
|
||||||
|
|
||||||
# for each beam, the most attended frame is:
|
# for each beam, the most attended frame is:
|
||||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||||
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
|
|
||||||
|
|
||||||
# Calculate absolute timestamps accounting for cumulative offset
|
# Calculate absolute timestamps accounting for cumulative offset
|
||||||
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
||||||
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
|
|
||||||
|
|
||||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
||||||
|
|
||||||
most_attended_frame = most_attended_frames[0].item()
|
most_attended_frame = most_attended_frames[0].item()
|
||||||
|
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||||
|
|
||||||
|
|
||||||
generation_progress.append(dict(generation_progress_loop))
|
|
||||||
logger.debug("current tokens" + str(current_tokens.shape))
|
logger.debug("current tokens" + str(current_tokens.shape))
|
||||||
if completed:
|
if completed:
|
||||||
# # stripping the last token, the eot
|
# # stripping the last token, the eot
|
||||||
@@ -609,66 +565,53 @@ class PaddedAlignAttWhisper:
|
|||||||
self.tokenizer.decode([current_tokens[i, -1].item()])
|
self.tokenizer.decode([current_tokens[i, -1].item()])
|
||||||
))
|
))
|
||||||
|
|
||||||
# for k,v in generation.items():
|
|
||||||
# print(k,v,file=sys.stderr)
|
|
||||||
# for x in generation_progress:
|
|
||||||
# for y in x.items():
|
|
||||||
# print("\t\t",*y,file=sys.stderr)
|
|
||||||
# print("\t","----", file=sys.stderr)
|
|
||||||
# print("\t", "end of generation_progress_loop", file=sys.stderr)
|
|
||||||
# sys.exit(1)
|
|
||||||
####################### End of decoding loop
|
|
||||||
|
|
||||||
logger.info("End of decoding loop")
|
|
||||||
|
|
||||||
# if attn_of_alignment_heads is not None:
|
|
||||||
# seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
|
|
||||||
|
|
||||||
# # Lets' now consider only the top hypothesis in the beam search
|
|
||||||
# top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
|
|
||||||
|
|
||||||
# # debug print: how is the new token attended?
|
|
||||||
# new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
|
|
||||||
# logger.debug(f"New token attention shape: {new_token_attn.shape}")
|
|
||||||
# if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
|
|
||||||
# logger.debug("no token generated")
|
|
||||||
# else: # it is, and the max attention is:
|
|
||||||
# new_token_max_attn, _ = new_token_attn.max(dim=-1)
|
|
||||||
# logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
|
|
||||||
|
|
||||||
|
|
||||||
# let's now operate only with the top beam hypothesis
|
|
||||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||||
if fire_detected or is_last:
|
|
||||||
|
if fire_detected or is_last: #or punctuation_stop:
|
||||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||||
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||||
else:
|
else:
|
||||||
# going to truncate the tokens after the last space
|
# going to truncate the tokens after the last space
|
||||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
|
||||||
generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
|
|
||||||
generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
|
|
||||||
|
|
||||||
# text_to_split = self.tokenizer.decode(tokens_to_split)
|
|
||||||
# logger.debug(f"text_to_split: {text_to_split}")
|
|
||||||
# logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
|
|
||||||
# text_before_space = " ".join(text_to_split.split(" ")[:-1])
|
|
||||||
# logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
|
|
||||||
if len(split_words) > 1:
|
if len(split_words) > 1:
|
||||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||||
else:
|
else:
|
||||||
new_hypothesis = []
|
new_hypothesis = []
|
||||||
|
|
||||||
|
|
||||||
### new hypothesis
|
|
||||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.tokens.append(new_tokens)
|
self.tokens.append(new_tokens)
|
||||||
# TODO: test if this is redundant or not
|
|
||||||
# ret = ret[ret<DEC_PAD]
|
|
||||||
|
|
||||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||||
|
|
||||||
self._clean_cache()
|
self._clean_cache()
|
||||||
|
|
||||||
return new_hypothesis, generation
|
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
|
||||||
|
self.first_timestamp = l_absolute_timestamps[0]
|
||||||
|
|
||||||
|
|
||||||
|
timestamped_words = []
|
||||||
|
timestamp_idx = 0
|
||||||
|
for word, word_tokens in zip(split_words, split_tokens):
|
||||||
|
try:
|
||||||
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
|
||||||
|
timestamp_entry = ASRToken(
|
||||||
|
start=current_timestamp,
|
||||||
|
end=current_timestamp + 0.1,
|
||||||
|
text= word,
|
||||||
|
probability=0.95,
|
||||||
|
speaker=self.speaker,
|
||||||
|
detected_language=self.detected_language
|
||||||
|
).with_offset(
|
||||||
|
self.global_time_offset
|
||||||
|
)
|
||||||
|
timestamped_words.append(timestamp_entry)
|
||||||
|
|
||||||
|
return timestamped_words
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, Any, List
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||||
|
|
||||||
def format_time(seconds: float) -> str:
|
def format_time(seconds: float) -> str:
|
||||||
"""Format seconds as HH:MM:SS."""
|
"""Format seconds as HH:MM:SS."""
|
||||||
return str(timedelta(seconds=int(seconds)))
|
return str(timedelta(seconds=int(seconds)))
|
||||||
@@ -15,12 +17,35 @@ class TimedText:
|
|||||||
speaker: Optional[int] = -1
|
speaker: Optional[int] = -1
|
||||||
probability: Optional[float] = None
|
probability: Optional[float] = None
|
||||||
is_dummy: Optional[bool] = False
|
is_dummy: Optional[bool] = False
|
||||||
|
detected_language: Optional[str] = None
|
||||||
|
|
||||||
@dataclass
|
def is_punctuation(self):
|
||||||
|
return self.text.strip() in PUNCTUATION_MARKS
|
||||||
|
|
||||||
|
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||||
|
return not (self.end <= other.start or other.end <= self.start)
|
||||||
|
|
||||||
|
def is_within(self, other: 'TimedText') -> bool:
|
||||||
|
return other.contains_timespan(self)
|
||||||
|
|
||||||
|
def duration(self) -> float:
|
||||||
|
return self.end - self.start
|
||||||
|
|
||||||
|
def contains_time(self, time: float) -> bool:
|
||||||
|
return self.start <= time <= self.end
|
||||||
|
|
||||||
|
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||||
|
return self.start <= other.start and self.end >= other.end
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return bool(self.text)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass()
|
||||||
class ASRToken(TimedText):
|
class ASRToken(TimedText):
|
||||||
def with_offset(self, offset: float) -> "ASRToken":
|
def with_offset(self, offset: float) -> "ASRToken":
|
||||||
"""Return a new token with the time offset added."""
|
"""Return a new token with the time offset added."""
|
||||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
|
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Sentence(TimedText):
|
class Sentence(TimedText):
|
||||||
@@ -28,7 +53,28 @@ class Sentence(TimedText):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Transcript(TimedText):
|
class Transcript(TimedText):
|
||||||
pass
|
"""
|
||||||
|
represents a concatenation of several ASRToken
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tokens(
|
||||||
|
cls,
|
||||||
|
tokens: List[ASRToken],
|
||||||
|
sep: Optional[str] = None,
|
||||||
|
offset: float = 0
|
||||||
|
) -> "Transcript":
|
||||||
|
sep = sep if sep is not None else ' '
|
||||||
|
text = sep.join(token.text for token in tokens)
|
||||||
|
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||||
|
if tokens:
|
||||||
|
start = offset + tokens[0].start
|
||||||
|
end = offset + tokens[-1].end
|
||||||
|
else:
|
||||||
|
start = None
|
||||||
|
end = None
|
||||||
|
return cls(start, end, text, probability=probability)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpeakerSegment(TimedText):
|
class SpeakerSegment(TimedText):
|
||||||
@@ -41,6 +87,34 @@ class SpeakerSegment(TimedText):
|
|||||||
class Translation(TimedText):
|
class Translation(TimedText):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def approximate_cut_at(self, cut_time):
|
||||||
|
"""
|
||||||
|
Each word in text is considered to be of duration (end-start)/len(words in text)
|
||||||
|
"""
|
||||||
|
if not self.text or not self.contains_time(cut_time):
|
||||||
|
return self, None
|
||||||
|
|
||||||
|
words = self.text.split()
|
||||||
|
num_words = len(words)
|
||||||
|
if num_words == 0:
|
||||||
|
return self, None
|
||||||
|
|
||||||
|
duration_per_word = self.duration() / num_words
|
||||||
|
|
||||||
|
cut_word_index = int((cut_time - self.start) / duration_per_word)
|
||||||
|
|
||||||
|
if cut_word_index >= num_words:
|
||||||
|
cut_word_index = num_words -1
|
||||||
|
|
||||||
|
text0 = " ".join(words[:cut_word_index])
|
||||||
|
text1 = " ".join(words[cut_word_index:])
|
||||||
|
|
||||||
|
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||||
|
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||||
|
|
||||||
|
return segment0, segment1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Silence():
|
class Silence():
|
||||||
duration: float
|
duration: float
|
||||||
@@ -51,13 +125,18 @@ class Line(TimedText):
|
|||||||
translation: str = ''
|
translation: str = ''
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
_dict = {
|
||||||
'speaker': int(self.speaker),
|
'speaker': int(self.speaker),
|
||||||
'text': self.text,
|
'text': self.text,
|
||||||
'translation': self.translation,
|
|
||||||
'start': format_time(self.start),
|
'start': format_time(self.start),
|
||||||
'end': format_time(self.end),
|
'end': format_time(self.end),
|
||||||
}
|
}
|
||||||
|
if self.translation:
|
||||||
|
_dict['translation'] = self.translation
|
||||||
|
if self.detected_language:
|
||||||
|
_dict['detected_language'] = self.detected_language
|
||||||
|
return _dict
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrontData():
|
class FrontData():
|
||||||
@@ -82,12 +161,16 @@ class FrontData():
|
|||||||
_dict['error'] = self.error
|
_dict['error'] = self.error
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChangeSpeaker:
|
||||||
|
speaker: int
|
||||||
|
start: int
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class State():
|
class State():
|
||||||
tokens: list
|
tokens: list
|
||||||
translated_segments: list
|
translated_segments: list
|
||||||
buffer_transcription: str
|
buffer_transcription: str
|
||||||
buffer_diarization: str
|
|
||||||
end_buffer: float
|
end_buffer: float
|
||||||
end_attributed_speaker: float
|
end_attributed_speaker: float
|
||||||
remaining_time_transcription: float
|
remaining_time_transcription: float
|
||||||
|
|||||||
@@ -1,42 +1,63 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||||
from whisperlivekit.timed_objects import Translation
|
from whisperlivekit.timed_objects import Translation
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
|
||||||
|
|
||||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if the previous
|
||||||
|
# sentence is not finished.
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TranslationModel():
|
class TranslationModel():
|
||||||
translator: ctranslate2.Translator
|
translator: ctranslate2.Translator
|
||||||
tokenizer: dict
|
device: str
|
||||||
|
tokenizer: dict = field(default_factory=dict)
|
||||||
|
backend_type: str = 'ctranslate2'
|
||||||
|
model_size: str = '600M'
|
||||||
|
|
||||||
def load_model(src_langs):
|
def get_tokenizer(self, input_lang):
|
||||||
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
if not self.tokenizer.get(input_lang, False):
|
||||||
MODEL_GUY = 'entai2965'
|
self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained(
|
||||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
f"facebook/nllb-200-distilled-{self.model_size}",
|
||||||
|
src_lang=input_lang,
|
||||||
|
clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
|
return self.tokenizer[input_lang]
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(src_langs, backend='ctranslate2', model_size='600M'):
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
translator = ctranslate2.Translator(MODEL,device=device)
|
MODEL = f'nllb-200-distilled-{model_size}-ctranslate2'
|
||||||
|
if backend=='ctranslate2':
|
||||||
|
MODEL_GUY = 'entai2965'
|
||||||
|
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||||
|
translator = ctranslate2.Translator(MODEL,device=device)
|
||||||
|
elif backend=='transformers':
|
||||||
|
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}")
|
||||||
tokenizer = dict()
|
tokenizer = dict()
|
||||||
for src_lang in src_langs:
|
for src_lang in src_langs:
|
||||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
if src_lang != 'auto':
|
||||||
return TranslationModel(
|
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||||
translator=translator,
|
|
||||||
tokenizer=tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
def translate(input, translation_model, tgt_lang):
|
translation_model = TranslationModel(
|
||||||
source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input))
|
translator=translator,
|
||||||
target_prefix = [tgt_lang]
|
tokenizer=tokenizer,
|
||||||
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
|
backend_type=backend,
|
||||||
target = results[0].hypotheses[0][1:]
|
device = device,
|
||||||
return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target))
|
model_size = model_size
|
||||||
|
)
|
||||||
|
for src_lang in src_langs:
|
||||||
|
if src_lang != 'auto':
|
||||||
|
translation_model.get_tokenizer(src_lang)
|
||||||
|
return translation_model
|
||||||
|
|
||||||
class OnlineTranslation:
|
class OnlineTranslation:
|
||||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||||
@@ -59,27 +80,38 @@ class OnlineTranslation:
|
|||||||
self.commited.extend(self.buffer[:i])
|
self.commited.extend(self.buffer[:i])
|
||||||
self.buffer = results[i:]
|
self.buffer = results[i:]
|
||||||
|
|
||||||
def translate(self, input, input_lang=None, output_lang=None):
|
def translate(self, input, input_lang, output_lang):
|
||||||
if not input:
|
if not input:
|
||||||
return ""
|
return ""
|
||||||
if input_lang is None:
|
|
||||||
input_lang = self.input_languages[0]
|
|
||||||
if output_lang is None:
|
|
||||||
output_lang = self.output_languages[0]
|
|
||||||
nllb_output_lang = get_nllb_code(output_lang)
|
nllb_output_lang = get_nllb_code(output_lang)
|
||||||
|
|
||||||
source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
|
tokenizer = self.translation_model.get_tokenizer(input_lang)
|
||||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) #we can use return_attention=True to try to optimize the stuff.
|
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
|
||||||
target = results[0].hypotheses[0][1:]
|
|
||||||
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
|
if self.translation_model.backend_type == 'ctranslate2':
|
||||||
return results
|
source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
|
||||||
|
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
|
||||||
|
target = results[0].hypotheses[0][1:]
|
||||||
|
result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||||||
|
else:
|
||||||
|
translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
|
||||||
|
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||||
|
return result
|
||||||
|
|
||||||
def translate_tokens(self, tokens):
|
def translate_tokens(self, tokens):
|
||||||
if tokens:
|
if tokens:
|
||||||
text = ' '.join([token.text for token in tokens])
|
text = ' '.join([token.text for token in tokens])
|
||||||
start = tokens[0].start
|
start = tokens[0].start
|
||||||
end = tokens[-1].end
|
end = tokens[-1].end
|
||||||
translated_text = self.translate(text)
|
if self.input_languages[0] == 'auto':
|
||||||
|
input_lang = tokens[0].detected_language
|
||||||
|
else:
|
||||||
|
input_lang = self.input_languages[0]
|
||||||
|
|
||||||
|
translated_text = self.translate(text,
|
||||||
|
input_lang,
|
||||||
|
self.output_languages[0]
|
||||||
|
)
|
||||||
translation = Translation(
|
translation = Translation(
|
||||||
text=translated_text,
|
text=translated_text,
|
||||||
start=start,
|
start=start,
|
||||||
@@ -89,7 +121,6 @@ class OnlineTranslation:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def insert_tokens(self, tokens):
|
def insert_tokens(self, tokens):
|
||||||
self.buffer.extend(tokens)
|
self.buffer.extend(tokens)
|
||||||
pass
|
pass
|
||||||
@@ -99,7 +130,7 @@ class OnlineTranslation:
|
|||||||
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
|
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
|
||||||
return self.validated + [self.translation_remaining]
|
return self.validated + [self.translation_remaining]
|
||||||
while i < len(self.buffer):
|
while i < len(self.buffer):
|
||||||
if self.buffer[i].text in PUNCTUATION_MARKS:
|
if self.buffer[i].is_punctuation():
|
||||||
translation_sentence = self.translate_tokens(self.buffer[:i+1])
|
translation_sentence = self.translate_tokens(self.buffer[:i+1])
|
||||||
self.validated.append(translation_sentence)
|
self.validated.append(translation_sentence)
|
||||||
self.buffer = self.buffer[i+1:]
|
self.buffer = self.buffer[i+1:]
|
||||||
@@ -110,6 +141,10 @@ class OnlineTranslation:
|
|||||||
self.len_processed_buffer = len(self.buffer)
|
self.len_processed_buffer = len(self.buffer)
|
||||||
return self.validated + [self.translation_remaining]
|
return self.validated + [self.translation_remaining]
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration: float):
|
||||||
|
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
|
||||||
|
self.buffer = []
|
||||||
|
self.validated += [self.translation_remaining]
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
output_lang = 'fr'
|
output_lang = 'fr'
|
||||||
@@ -122,16 +157,13 @@ if __name__ == '__main__':
|
|||||||
test = test_string.split(' ')
|
test = test_string.split(' ')
|
||||||
step = len(test) // 3
|
step = len(test) // 3
|
||||||
|
|
||||||
shared_model = load_model([input_lang])
|
shared_model = load_model([input_lang], backend='ctranslate2')
|
||||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||||
|
|
||||||
|
beg_inference = time.time()
|
||||||
for id in range(5):
|
for id in range(5):
|
||||||
val = test[id*step : (id+1)*step]
|
val = test[id*step : (id+1)*step]
|
||||||
val_str = ' '.join(val)
|
val_str = ' '.join(val)
|
||||||
result = online_translation.translate(val_str)
|
result = online_translation.translate(val_str)
|
||||||
print(result)
|
print(result)
|
||||||
|
print('inference time:', time.time() - beg_inference)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# print(result)
|
|
||||||
@@ -346,7 +346,7 @@ label {
|
|||||||
|
|
||||||
.label_diarization {
|
.label_diarization {
|
||||||
background-color: var(--chip-bg);
|
background-color: var(--chip-bg);
|
||||||
border-radius: 8px 8px 8px 8px;
|
border-radius: 100px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
margin-left: 10px;
|
margin-left: 10px;
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
@@ -358,7 +358,7 @@ label {
|
|||||||
|
|
||||||
.label_transcription {
|
.label_transcription {
|
||||||
background-color: var(--chip-bg);
|
background-color: var(--chip-bg);
|
||||||
border-radius: 8px 8px 8px 8px;
|
border-radius: 100px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
@@ -370,16 +370,20 @@ label {
|
|||||||
|
|
||||||
.label_translation {
|
.label_translation {
|
||||||
background-color: var(--chip-bg);
|
background-color: var(--chip-bg);
|
||||||
|
display: inline-flex;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
padding: 4px 8px;
|
padding: 4px 8px;
|
||||||
margin-top: 4px;
|
margin-top: 4px;
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
color: var(--text);
|
color: var(--text);
|
||||||
display: flex;
|
|
||||||
align-items: flex-start;
|
align-items: flex-start;
|
||||||
gap: 4px;
|
gap: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.lag-diarization-value {
|
||||||
|
margin-left: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
.label_translation img {
|
.label_translation img {
|
||||||
margin-top: 2px;
|
margin-top: 2px;
|
||||||
}
|
}
|
||||||
@@ -391,7 +395,7 @@ label {
|
|||||||
|
|
||||||
#timeInfo {
|
#timeInfo {
|
||||||
color: var(--muted);
|
color: var(--muted);
|
||||||
margin-left: 10px;
|
margin-left: 0px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.textcontent {
|
.textcontent {
|
||||||
@@ -438,7 +442,6 @@ label {
|
|||||||
font-size: 13px;
|
font-size: 13px;
|
||||||
border-radius: 30px;
|
border-radius: 30px;
|
||||||
padding: 2px 10px;
|
padding: 2px 10px;
|
||||||
display: none;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.loading {
|
.loading {
|
||||||
@@ -515,3 +518,33 @@ label {
|
|||||||
padding: 10px;
|
padding: 10px;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.label_language {
|
||||||
|
background-color: var(--chip-bg);
|
||||||
|
margin-bottom: 0px;
|
||||||
|
margin-top: 5px;
|
||||||
|
height: 18.5px;
|
||||||
|
border-radius: 100px;
|
||||||
|
padding: 2px 8px;
|
||||||
|
margin-left: 10px;
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 4px;
|
||||||
|
font-size: 14px;
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
.speaker-badge {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
width: 16px;
|
||||||
|
height: 16px;
|
||||||
|
margin-left: -5px;
|
||||||
|
border-radius: 50%;
|
||||||
|
font-size: 11px;
|
||||||
|
line-height: 1;
|
||||||
|
font-weight: 800;
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ let lastReceivedData = null;
|
|||||||
let lastSignature = null;
|
let lastSignature = null;
|
||||||
let availableMicrophones = [];
|
let availableMicrophones = [];
|
||||||
let selectedMicrophoneId = null;
|
let selectedMicrophoneId = null;
|
||||||
|
let serverUseAudioWorklet = null;
|
||||||
|
let configReadyResolve;
|
||||||
|
const configReady = new Promise((r) => (configReadyResolve = r));
|
||||||
|
|
||||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||||
@@ -37,6 +40,11 @@ const timerElement = document.querySelector(".timer");
|
|||||||
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||||
const microphoneSelect = document.getElementById("microphoneSelect");
|
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||||
|
|
||||||
|
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
|
||||||
|
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
|
||||||
|
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
|
||||||
|
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
|
||||||
|
|
||||||
function getWaveStroke() {
|
function getWaveStroke() {
|
||||||
const styles = getComputedStyle(document.documentElement);
|
const styles = getComputedStyle(document.documentElement);
|
||||||
const v = styles.getPropertyValue("--wave-stroke").trim();
|
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||||
@@ -228,6 +236,14 @@ function setupWebSocket() {
|
|||||||
|
|
||||||
websocket.onmessage = (event) => {
|
websocket.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
|
if (data.type === "config") {
|
||||||
|
serverUseAudioWorklet = !!data.useAudioWorklet;
|
||||||
|
statusText.textContent = serverUseAudioWorklet
|
||||||
|
? "Connected. Using AudioWorklet (PCM)."
|
||||||
|
: "Connected. Using MediaRecorder (WebM).";
|
||||||
|
if (configReadyResolve) configReadyResolve();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (data.type === "ready_to_stop") {
|
if (data.type === "ready_to_stop") {
|
||||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||||
@@ -295,7 +311,7 @@ function renderLinesWithBuffer(
|
|||||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||||
const signature = JSON.stringify({
|
const signature = JSON.stringify({
|
||||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end })),
|
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
||||||
buffer_transcription: buffer_transcription || "",
|
buffer_transcription: buffer_transcription || "",
|
||||||
buffer_diarization: buffer_diarization || "",
|
buffer_diarization: buffer_diarization || "",
|
||||||
status: current_status,
|
status: current_status,
|
||||||
@@ -324,24 +340,22 @@ function renderLinesWithBuffer(
|
|||||||
|
|
||||||
let speakerLabel = "";
|
let speakerLabel = "";
|
||||||
if (item.speaker === -2) {
|
if (item.speaker === -2) {
|
||||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
} else if (item.speaker == 0 && !isFinalizing) {
|
} else if (item.speaker == 0 && !isFinalizing) {
|
||||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||||
remaining_time_diarization
|
remaining_time_diarization
|
||||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||||
} else if (item.speaker !== 0) {
|
} else if (item.speaker !== 0) {
|
||||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
|
||||||
|
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
|
|
||||||
|
if (item.detected_language) {
|
||||||
|
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let currentLineText = item.text || "";
|
let currentLineText = item.text || "";
|
||||||
|
|
||||||
if (item.translation) {
|
|
||||||
currentLineText += `<div class="label_translation">
|
|
||||||
<img src="/web/src/translate.svg" alt="Translation" width="12" height="12" />
|
|
||||||
<span>${item.translation}</span>
|
|
||||||
</div>`;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (idx === lines.length - 1) {
|
if (idx === lines.length - 1) {
|
||||||
if (!isFinalizing && item.speaker !== -2) {
|
if (!isFinalizing && item.speaker !== -2) {
|
||||||
if (remaining_time_transcription > 0) {
|
if (remaining_time_transcription > 0) {
|
||||||
@@ -375,6 +389,13 @@ function renderLinesWithBuffer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (item.translation) {
|
||||||
|
currentLineText += `<div class="label_translation">
|
||||||
|
${translationIcon}
|
||||||
|
<span>${item.translation}</span>
|
||||||
|
</div>`;
|
||||||
|
}
|
||||||
|
|
||||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||||
: `<p>${speakerLabel}<br/></p>`;
|
: `<p>${speakerLabel}<br/></p>`;
|
||||||
@@ -459,38 +480,54 @@ async function startRecording() {
|
|||||||
microphone = audioContext.createMediaStreamSource(stream);
|
microphone = audioContext.createMediaStreamSource(stream);
|
||||||
microphone.connect(analyser);
|
microphone.connect(analyser);
|
||||||
|
|
||||||
if (!audioContext.audioWorklet) {
|
if (serverUseAudioWorklet) {
|
||||||
throw new Error("AudioWorklet is not supported in this browser");
|
if (!audioContext.audioWorklet) {
|
||||||
}
|
throw new Error("AudioWorklet is not supported in this browser");
|
||||||
await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");
|
|
||||||
workletNode = new AudioWorkletNode(audioContext, "pcm-forwarder", { numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1 });
|
|
||||||
microphone.connect(workletNode);
|
|
||||||
|
|
||||||
recorderWorker = new Worker("/web/recorder_worker.js");
|
|
||||||
recorderWorker.postMessage({
|
|
||||||
command: "init",
|
|
||||||
config: {
|
|
||||||
sampleRate: audioContext.sampleRate,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
recorderWorker.onmessage = (e) => {
|
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
|
||||||
websocket.send(e.data.buffer);
|
|
||||||
}
|
}
|
||||||
};
|
await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");
|
||||||
|
workletNode = new AudioWorkletNode(audioContext, "pcm-forwarder", { numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1 });
|
||||||
|
microphone.connect(workletNode);
|
||||||
|
|
||||||
workletNode.port.onmessage = (e) => {
|
recorderWorker = new Worker("/web/recorder_worker.js");
|
||||||
const data = e.data;
|
recorderWorker.postMessage({
|
||||||
const ab = data instanceof ArrayBuffer ? data : data.buffer;
|
command: "init",
|
||||||
recorderWorker.postMessage(
|
config: {
|
||||||
{
|
sampleRate: audioContext.sampleRate,
|
||||||
command: "record",
|
|
||||||
buffer: ab,
|
|
||||||
},
|
},
|
||||||
[ab]
|
});
|
||||||
);
|
|
||||||
};
|
recorderWorker.onmessage = (e) => {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
websocket.send(e.data.buffer);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
workletNode.port.onmessage = (e) => {
|
||||||
|
const data = e.data;
|
||||||
|
const ab = data instanceof ArrayBuffer ? data : data.buffer;
|
||||||
|
recorderWorker.postMessage(
|
||||||
|
{
|
||||||
|
command: "record",
|
||||||
|
buffer: ab,
|
||||||
|
},
|
||||||
|
[ab]
|
||||||
|
);
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||||
|
} catch (e) {
|
||||||
|
recorder = new MediaRecorder(stream);
|
||||||
|
}
|
||||||
|
recorder.ondataavailable = (e) => {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
if (e.data && e.data.size > 0) {
|
||||||
|
websocket.send(e.data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
recorder.start(chunkDuration);
|
||||||
|
}
|
||||||
|
|
||||||
startTime = Date.now();
|
startTime = Date.now();
|
||||||
timerInterval = setInterval(updateTimer, 1000);
|
timerInterval = setInterval(updateTimer, 1000);
|
||||||
@@ -528,6 +565,14 @@ async function stopRecording() {
|
|||||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (recorder) {
|
||||||
|
try {
|
||||||
|
recorder.stop();
|
||||||
|
} catch (e) {
|
||||||
|
}
|
||||||
|
recorder = null;
|
||||||
|
}
|
||||||
|
|
||||||
if (recorderWorker) {
|
if (recorderWorker) {
|
||||||
recorderWorker.terminate();
|
recorderWorker.terminate();
|
||||||
recorderWorker = null;
|
recorderWorker = null;
|
||||||
@@ -586,9 +631,11 @@ async function toggleRecording() {
|
|||||||
console.log("Connecting to WebSocket");
|
console.log("Connecting to WebSocket");
|
||||||
try {
|
try {
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
await configReady;
|
||||||
await startRecording();
|
await startRecording();
|
||||||
} else {
|
} else {
|
||||||
await setupWebSocket();
|
await setupWebSocket();
|
||||||
|
await configReady;
|
||||||
await startRecording();
|
await startRecording();
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
1
whisperlivekit/web/src/language.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>
|
||||||
|
After Width: | Height: | Size: 976 B |
1
whisperlivekit/web/src/silence.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>
|
||||||
|
After Width: | Height: | Size: 984 B |
1
whisperlivekit/web/src/speaker.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>
|
||||||
|
After Width: | Height: | Size: 592 B |