mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
75 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12973711f6 | ||
|
|
909ac9dd41 | ||
|
|
d94a07d417 | ||
|
|
b32dd8bfc4 | ||
|
|
9feb0e597b | ||
|
|
9dab84a573 | ||
|
|
d089c7fce0 | ||
|
|
253a080df5 | ||
|
|
0c6e4b2aee | ||
|
|
e14bbde77d | ||
|
|
7496163467 | ||
|
|
696a94d1ce | ||
|
|
2699b0974c | ||
|
|
90c0250ba4 | ||
|
|
eb96153ffd | ||
|
|
47e3eb9b5b | ||
|
|
b8b07adeef | ||
|
|
d0e9e37ef6 | ||
|
|
820f92d8cb | ||
|
|
e42523af84 | ||
|
|
e2184d5e06 | ||
|
|
7fe0353260 | ||
|
|
0f2eba507e | ||
|
|
55e08474f3 | ||
|
|
28bdc52e1d | ||
|
|
e4221fa6c3 | ||
|
|
1652db9a2d | ||
|
|
601f17653a | ||
|
|
7718190fcd | ||
|
|
349c7dcb9e | ||
|
|
1c42b867cf | ||
|
|
d4771e563e | ||
|
|
b0a5fc0693 | ||
|
|
3b96fb8776 | ||
|
|
7f93c4b978 | ||
|
|
15c3df1cba | ||
|
|
7fb8e66c01 | ||
|
|
728e1f1290 | ||
|
|
87b9ed6ecd | ||
|
|
38b4ebe8ba | ||
|
|
d098af3185 | ||
|
|
4e56130a40 | ||
|
|
2bbdc70187 | ||
|
|
b678a55f63 | ||
|
|
5491964e81 | ||
|
|
b05297a96d | ||
|
|
197293e25e | ||
|
|
ba41c4ab56 | ||
|
|
bda72b8bc0 | ||
|
|
bb6b9f4cb1 | ||
|
|
e40b5a3ea0 | ||
|
|
4cfed6e98e | ||
|
|
687e3dd5e2 | ||
|
|
e4140cd299 | ||
|
|
8e056cbdf2 | ||
|
|
9dcfb38967 | ||
|
|
47b9235d70 | ||
|
|
f3cd53a4db | ||
|
|
dbdb4ea66c | ||
|
|
00424d7ca3 | ||
|
|
4b738d6f63 | ||
|
|
8a5e2adb1e | ||
|
|
f85329e112 | ||
|
|
46efbdf1d9 | ||
|
|
8885ade003 | ||
|
|
2564928d83 | ||
|
|
56114d3071 | ||
|
|
5b9977c9af | ||
|
|
12a544164f | ||
|
|
2ca1156b7e | ||
|
|
3ad3683ca7 | ||
|
|
1599bd87a0 | ||
|
|
90623400a4 | ||
|
|
64e44fb24f | ||
|
|
156b9a133f |
@@ -15,7 +15,7 @@ Thank you for considering contributing ! We appreciate your time and effort to h
|
|||||||
|
|
||||||
## Opening Issues
|
## Opening Issues
|
||||||
|
|
||||||
If you encounter a problem with diart or want to suggest an improvement, please follow these guidelines when opening an issue:
|
If you encounter a problem with WhisperLiveKit or want to suggest an improvement, please follow these guidelines when opening an issue:
|
||||||
|
|
||||||
- **Bug Reports:**
|
- **Bug Reports:**
|
||||||
- Clearly describe the error. **Please indicate the parameters you use, especially the model(s)**
|
- Clearly describe the error. **Please indicate the parameters you use, especially the model(s)**
|
||||||
@@ -43,4 +43,4 @@ We welcome and appreciate contributions! To ensure a smooth review process, plea
|
|||||||
|
|
||||||
## Thank You
|
## Thank You
|
||||||
|
|
||||||
Your contributions make diart better for everyone. Thank you for your time and dedication!
|
Your contributions make WhisperLiveKit better for everyone. Thank you for your time and dedication!
|
||||||
|
|||||||
@@ -21,10 +21,12 @@ RUN apt-get update && \
|
|||||||
python3 \
|
python3 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
git && \
|
git \
|
||||||
|
build-essential \
|
||||||
|
python3-dev && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
@@ -79,4 +81,4 @@ EXPOSE 8000
|
|||||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||||
|
|
||||||
# Default args
|
# Default args
|
||||||
CMD ["--model", "tiny.en"]
|
CMD ["--model", "base"]
|
||||||
247
README.md
247
README.md
@@ -4,7 +4,7 @@
|
|||||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Diarization</b></p>
|
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||||
@@ -13,124 +13,94 @@
|
|||||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine ✨
|
Real-time speech transcription directly to your browser, with a ready-to-use backend+server and a simple frontend. ✨
|
||||||
|
|
||||||
|
#### Powered by Leading Research:
|
||||||
|
|
||||||
|
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription with AlignAtt policy
|
||||||
|
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription with LocalAgreement policy
|
||||||
|
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||||
|
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||||
|
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
|
||||||
|
|
||||||
|
|
||||||
|
> **Why not just run a simple Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts off words mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.
|
||||||
|
|
||||||
|
|
||||||
### Architecture
|
### Architecture
|
||||||
|
|
||||||
WhisperLiveKit consists of three main components:
|
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
|
||||||
|
|
||||||
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the [provided template](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
|
*The backend supports multiple concurrent users. Voice Activity Detection reduces overhead when no voice is detected.*
|
||||||
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
|
|
||||||
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
|
|
||||||
|
|
||||||
|
### Installation & Quick Start
|
||||||
### Key Features
|
|
||||||
|
|
||||||
- **Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
|
|
||||||
- **Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
|
|
||||||
- **Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
|
||||||
- **Automatic Silence Chunking** – Automatically chunks when no audio is detected to limit buffer size
|
|
||||||
- **Confidence Validation** – Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
|
|
||||||
- **Buffering Preview** – Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
|
|
||||||
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
|
|
||||||
- **SimulStreaming Backend** - [Dual-licensed](https://github.com/ufal/SimulStreaming#-licence-and-contributions) - Ultra-low latency transcription using SOTA AlignAtt policy.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Install the package
|
|
||||||
pip install whisperlivekit
|
pip install whisperlivekit
|
||||||
|
|
||||||
# Start the transcription server
|
|
||||||
whisperlivekit-server --model tiny.en
|
|
||||||
|
|
||||||
# Open your browser at http://localhost:8000 to see the interface.
|
|
||||||
# Use -ssl-certfile public.crt --ssl-keyfile private.key parameters to use SSL
|
|
||||||
```
|
```
|
||||||
|
|
||||||
That's it! Start speaking and watch your words appear on screen.
|
> **FFmpeg is required** and must be installed before using WhisperLiveKit
|
||||||
|
>
|
||||||
|
> | OS | How to install |
|
||||||
|
> |-----------|-------------|
|
||||||
|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
||||||
|
> | MacOS | `brew install ffmpeg` |
|
||||||
|
> | Windows | Download .exe from https://ffmpeg.org/download.html and add to PATH |
|
||||||
|
|
||||||
## Installation
|
#### Quick Start
|
||||||
|
1. **Start the transcription server:**
|
||||||
|
```bash
|
||||||
|
whisperlivekit-server --model base --language en
|
||||||
|
```
|
||||||
|
|
||||||
```bash
|
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||||
#Install from PyPI (Recommended)
|
|
||||||
pip install whisperlivekit
|
|
||||||
|
|
||||||
#Install from Source
|
|
||||||
git clone https://github.com/QuentinFuxa/WhisperLiveKit
|
|
||||||
cd WhisperLiveKit
|
|
||||||
pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
### FFmpeg Dependency
|
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||||
|
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||||
|
|
||||||
```bash
|
|
||||||
# Ubuntu/Debian
|
|
||||||
sudo apt install ffmpeg
|
|
||||||
|
|
||||||
# macOS
|
#### Optional Dependencies
|
||||||
brew install ffmpeg
|
|
||||||
|
|
||||||
# Windows
|
| Optional | `pip install` |
|
||||||
# Download from https://ffmpeg.org/download.html and add to PATH
|
|-----------|-------------|
|
||||||
```
|
| Speaker diarization | `whisperlivekit[diarization]` |
|
||||||
|
| Original Whisper backend | `whisperlivekit[whisper]` |
|
||||||
|
| Improved timestamps backend | `whisperlivekit[whisper-timestamped]` |
|
||||||
|
| Apple Silicon optimization backend | `whisperlivekit[mlx-whisper]` |
|
||||||
|
| OpenAI API backend | `whisperlivekit[openai]` |
|
||||||
|
|
||||||
### Optional Dependencies
|
See **Parameters & Configuration** below on how to use them.
|
||||||
|
|
||||||
```bash
|
|
||||||
# Voice Activity Controller (prevents hallucinations)
|
> **Pyannote Models Setup** For diarization, you need access to pyannote.audio models:
|
||||||
pip install torch
|
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||||
|
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
||||||
# Sentence-based buffer trimming
|
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
|
||||||
pip install mosestokenizer wtpsplit
|
>4. Login with HuggingFace:
|
||||||
pip install tokenize_uk # If you work with Ukrainian text
|
> ```bash
|
||||||
|
> huggingface-cli login
|
||||||
# Speaker diarization
|
> ```
|
||||||
pip install diart
|
|
||||||
|
|
||||||
# Alternative Whisper backends (default is faster-whisper)
|
|
||||||
pip install whisperlivekit[whisper] # Original Whisper
|
|
||||||
pip install whisperlivekit[whisper-timestamped] # Improved timestamps
|
|
||||||
pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
|
|
||||||
pip install whisperlivekit[openai] # OpenAI API
|
|
||||||
pip install whisperlivekit[simulstreaming]
|
|
||||||
```
|
|
||||||
|
|
||||||
### 🎹 Pyannote Models Setup
|
|
||||||
|
|
||||||
For diarization, you need access to pyannote.audio models:
|
|
||||||
|
|
||||||
1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
|
||||||
2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
|
||||||
3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
|
|
||||||
4. Login with HuggingFace:
|
|
||||||
```bash
|
|
||||||
pip install huggingface_hub
|
|
||||||
huggingface-cli login
|
|
||||||
```
|
|
||||||
|
|
||||||
## 💻 Usage Examples
|
## 💻 Usage Examples
|
||||||
|
|
||||||
### Command-line Interface
|
#### Command-line Interface
|
||||||
|
|
||||||
Start the transcription server with various options:
|
Start the transcription server with various options:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Basic server with English model
|
# SimulStreaming backend for ultra-low latency
|
||||||
whisperlivekit-server --model tiny.en
|
whisperlivekit-server --backend simulstreaming --model large-v3
|
||||||
|
|
||||||
# Advanced configuration with diarization
|
# Advanced configuration with diarization
|
||||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto
|
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
||||||
|
|
||||||
# SimulStreaming backend for ultra-low latency
|
|
||||||
whisperlivekit-server --backend simulstreaming --model large-v3 --frame-threshold 20
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Python API Integration (Backend)
|
#### Python API Integration (Backend)
|
||||||
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
|
Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
||||||
@@ -145,14 +115,10 @@ transcription_engine = None
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global transcription_engine
|
global transcription_engine
|
||||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
||||||
# You can also load from command-line arguments using parse_args()
|
|
||||||
# args = parse_args()
|
|
||||||
# transcription_engine = TranscriptionEngine(**vars(args))
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
# Process WebSocket connections
|
|
||||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
||||||
async for response in results_generator:
|
async for response in results_generator:
|
||||||
await websocket.send_json(response)
|
await websocket.send_json(response)
|
||||||
@@ -172,43 +138,36 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
await audio_processor.process_audio(message)
|
await audio_processor.process_audio(message)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Frontend Implementation
|
#### Frontend Implementation
|
||||||
|
|
||||||
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html), or load its content using `get_web_interface_html()` :
|
The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
|
||||||
|
|
||||||
```python
|
|
||||||
from whisperlivekit import get_web_interface_html
|
|
||||||
html_content = get_web_interface_html()
|
|
||||||
```
|
|
||||||
|
|
||||||
## ⚙️ Configuration Reference
|
### ⚙️ Parameters & Configuration
|
||||||
|
|
||||||
WhisperLiveKit offers extensive configuration options:
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--host` | Server host address | `localhost` |
|
| `--model` | Whisper model size. | `small` |
|
||||||
| `--port` | Server port | `8000` |
|
|
||||||
| `--model` | Whisper model size. Caution : '.en' models do not work with Simulstreaming | `tiny` |
|
|
||||||
| `--language` | Source language code or `auto` | `en` |
|
| `--language` | Source language code or `auto` | `en` |
|
||||||
| `--task` | `transcribe` or `translate` | `transcribe` |
|
| `--task` | `transcribe` or `translate` | `transcribe` |
|
||||||
| `--backend` | Processing backend | `faster-whisper` |
|
| `--backend` | Processing backend | `simulstreaming` |
|
||||||
| `--diarization` | Enable speaker identification | `False` |
|
|
||||||
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
|
|
||||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
|
||||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
||||||
| `--vac` | Use Voice Activity Controller | `False` |
|
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
|
||||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||||
|
| `--host` | Server host address | `localhost` |
|
||||||
|
| `--port` | Server port | `8000` |
|
||||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||||
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
|
||||||
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
|
||||||
|
|
||||||
**SimulStreaming-specific Options:**
|
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| WhisperStreaming backend options | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||||
|
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||||
|
|
||||||
|
|
||||||
|
| SimulStreaming backend options | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||||
@@ -221,68 +180,57 @@ WhisperLiveKit offers extensive configuration options:
|
|||||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||||
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
||||||
|
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||||
|
|
||||||
## 🔧 How It Works
|
| Diarization options | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
|
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
|
||||||
|
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||||
|
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||||
|
|
||||||
1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format
|
### 🚀 Deployment Guide
|
||||||
2. **Streaming**: Audio chunks are sent to the server via WebSocket
|
|
||||||
3. **Processing**: Server decodes audio with FFmpeg and streams into the model for transcription
|
|
||||||
4. **Real-time Output**: Partial transcriptions appear immediately in light gray (the 'aperçu') and finalized text appears in normal color
|
|
||||||
|
|
||||||
## 🚀 Deployment Guide
|
|
||||||
|
|
||||||
To deploy WhisperLiveKit in production:
|
To deploy WhisperLiveKit in production:
|
||||||
|
|
||||||
1. **Server Setup** (Backend):
|
1. **Server Setup**: Install production ASGI server & launch with multiple workers
|
||||||
```bash
|
```bash
|
||||||
# Install production ASGI server
|
|
||||||
pip install uvicorn gunicorn
|
pip install uvicorn gunicorn
|
||||||
|
|
||||||
# Launch with multiple workers
|
|
||||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Frontend Integration**:
|
2. **Frontend**: Host your customized version of the `html` example & ensure WebSocket connection points correctly
|
||||||
- Host your customized version of the example HTML/JS in your web application
|
|
||||||
- Ensure WebSocket connection points to your server's address
|
|
||||||
|
|
||||||
3. **Nginx Configuration** (recommended for production):
|
3. **Nginx Configuration** (recommended for production):
|
||||||
```nginx
|
```nginx
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
server_name your-domain.com;
|
server_name your-domain.com;
|
||||||
|
location / {
|
||||||
location / {
|
proxy_pass http://localhost:8000;
|
||||||
proxy_pass http://localhost:8000;
|
proxy_set_header Upgrade $http_upgrade;
|
||||||
proxy_set_header Upgrade $http_upgrade;
|
proxy_set_header Connection "upgrade";
|
||||||
proxy_set_header Connection "upgrade";
|
proxy_set_header Host $host;
|
||||||
proxy_set_header Host $host;
|
|
||||||
}}
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
||||||
|
|
||||||
### 🐋 Docker
|
### 🐋 Docker
|
||||||
|
|
||||||
A basic Dockerfile is provided which allows re-use of Python package installation options. ⚠️ For **large** models, ensure that your **docker runtime** has enough **memory** available. See below usage examples:
|
A Dockerfile is provided which allows re-use of Python package installation options. Create a reusable image with only the basics and then run as a named container:
|
||||||
|
|
||||||
|
|
||||||
#### All defaults
|
|
||||||
- Create a reusable image with only the basics and then run as a named container:
|
|
||||||
```bash
|
```bash
|
||||||
docker build -t whisperlivekit-defaults .
|
docker build -t whisperlivekit-defaults .
|
||||||
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
|
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults --model base
|
||||||
docker start -i whisperlivekit
|
docker start -i whisperlivekit
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **Note**: For **large** models, ensure that your **docker runtime** has enough **memory** available
|
||||||
|
|
||||||
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
|
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
|
||||||
|
|
||||||
#### Customization
|
#### Customization
|
||||||
- Customize the container options:
|
|
||||||
```bash
|
|
||||||
docker build -t whisperlivekit-defaults .
|
|
||||||
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
|
|
||||||
docker start -i whisperlivekit-base
|
|
||||||
```
|
|
||||||
|
|
||||||
- `--build-arg` Options:
|
- `--build-arg` Options:
|
||||||
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
||||||
@@ -291,10 +239,3 @@ docker start -i whisperlivekit-base
|
|||||||
|
|
||||||
## 🔮 Use Cases
|
## 🔮 Use Cases
|
||||||
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
|
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
|
||||||
|
|
||||||
## 🙏 Acknowledgments
|
|
||||||
|
|
||||||
We extend our gratitude to the original authors of:
|
|
||||||
|
|
||||||
| [Whisper Streaming](https://github.com/ufal/whisper_streaming) | [SimulStreaming](https://github.com/ufal/SimulStreaming) | [Diart](https://github.com/juanmc2005/diart) | [OpenAI Whisper](https://github.com/openai/whisper) |
|
|
||||||
| -------- | ------- | -------- | ------- |
|
|
||||||
|
|||||||
BIN
architecture.png
Normal file
BIN
architecture.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 388 KiB |
BIN
demo.png
BIN
demo.png
Binary file not shown.
|
Before Width: | Height: | Size: 438 KiB After Width: | Height: | Size: 423 KiB |
56
pyproject.toml
Normal file
56
pyproject.toml
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "whisperlivekit"
|
||||||
|
version = "0.2.6"
|
||||||
|
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
|
||||||
|
readme = "README.md"
|
||||||
|
authors = [
|
||||||
|
{ name = "Quentin Fuxa" }
|
||||||
|
]
|
||||||
|
license = { file = "LICENSE" }
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
|
"Topic :: Multimedia :: Sound/Audio :: Speech"
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"fastapi",
|
||||||
|
"librosa",
|
||||||
|
"soundfile",
|
||||||
|
"faster-whisper",
|
||||||
|
"uvicorn",
|
||||||
|
"websockets",
|
||||||
|
"torch",
|
||||||
|
"tqdm",
|
||||||
|
"tiktoken",
|
||||||
|
'triton>=2.0.0,<3; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
diarization = ["diart"]
|
||||||
|
sentence = ["mosestokenizer", "wtpsplit"]
|
||||||
|
whisper = ["whisper"]
|
||||||
|
whisper-timestamped = ["whisper-timestamped"]
|
||||||
|
mlx-whisper = ["mlx-whisper"]
|
||||||
|
openai = ["openai"]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom"]
|
||||||
|
|
||||||
|
[tool.setuptools.package-data]
|
||||||
|
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||||
|
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||||
55
setup.py
55
setup.py
@@ -1,55 +0,0 @@
|
|||||||
from setuptools import setup, find_packages
|
|
||||||
setup(
|
|
||||||
name="whisperlivekit",
|
|
||||||
version="0.2.1",
|
|
||||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
|
||||||
long_description_content_type="text/markdown",
|
|
||||||
author="Quentin Fuxa",
|
|
||||||
url="https://github.com/QuentinFuxa/WhisperLiveKit",
|
|
||||||
packages=find_packages(),
|
|
||||||
install_requires=[
|
|
||||||
"fastapi",
|
|
||||||
"librosa",
|
|
||||||
"soundfile",
|
|
||||||
"faster-whisper",
|
|
||||||
"uvicorn",
|
|
||||||
"websockets",
|
|
||||||
],
|
|
||||||
extras_require={
|
|
||||||
"diarization": ["diart"],
|
|
||||||
"vac": ["torch"],
|
|
||||||
"sentence": ["mosestokenizer", "wtpsplit"],
|
|
||||||
"whisper": ["whisper"],
|
|
||||||
"whisper-timestamped": ["whisper-timestamped"],
|
|
||||||
"mlx-whisper": ["mlx-whisper"],
|
|
||||||
"openai": ["openai"],
|
|
||||||
"simulstreaming": [
|
|
||||||
"torch",
|
|
||||||
"tqdm",
|
|
||||||
"tiktoken",
|
|
||||||
"numpy<2.0.0",
|
|
||||||
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
package_data={
|
|
||||||
'whisperlivekit': ['web/*.html'],
|
|
||||||
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
|
|
||||||
'whisperlivekit.simul_whisper.whisper.assets': ['*.tiktoken', '*.npz'],
|
|
||||||
},
|
|
||||||
entry_points={
|
|
||||||
'console_scripts': [
|
|
||||||
'whisperlivekit-server=whisperlivekit.basic_server:main',
|
|
||||||
],
|
|
||||||
},
|
|
||||||
classifiers=[
|
|
||||||
"Development Status :: 4 - Beta",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
|
||||||
],
|
|
||||||
python_requires=">=3.9",
|
|
||||||
)
|
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
from .download_simulstreaming_backend import download_simulstreaming_backend
|
|
||||||
from .audio_processor import AudioProcessor
|
from .audio_processor import AudioProcessor
|
||||||
from .core import TranscriptionEngine
|
from .core import TranscriptionEngine
|
||||||
from .parse_args import parse_args
|
from .parse_args import parse_args
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ import math
|
|||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken, Silence
|
||||||
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
|
from whisperlivekit.core import TranscriptionEngine, online_factory
|
||||||
from whisperlivekit.core import TranscriptionEngine
|
|
||||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
|
from whisperlivekit.remove_silences import handle_silences
|
||||||
|
from whisperlivekit.trail_repetition import trim_tail_repetition
|
||||||
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||||
# Set up logging once
|
# Set up logging once
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -46,17 +47,19 @@ class AudioProcessor:
|
|||||||
self.last_ffmpeg_activity = time()
|
self.last_ffmpeg_activity = time()
|
||||||
self.ffmpeg_health_check_interval = 5
|
self.ffmpeg_health_check_interval = 5
|
||||||
self.ffmpeg_max_idle_time = 10
|
self.ffmpeg_max_idle_time = 10
|
||||||
|
self.debug = False
|
||||||
|
|
||||||
# State management
|
# State management
|
||||||
self.is_stopping = False
|
self.is_stopping = False
|
||||||
|
self.silence = False
|
||||||
|
self.silence_duration = 0.0
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.buffer_transcription = ""
|
self.buffer_transcription = ""
|
||||||
self.buffer_diarization = ""
|
self.buffer_diarization = ""
|
||||||
self.full_transcription = ""
|
|
||||||
self.end_buffer = 0
|
self.end_buffer = 0
|
||||||
self.end_attributed_speaker = 0
|
self.end_attributed_speaker = 0
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
self.beg_loop = time()
|
self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
|
||||||
self.sep = " " # Default separator
|
self.sep = " " # Default separator
|
||||||
self.last_response_content = ""
|
self.last_response_content = ""
|
||||||
|
|
||||||
@@ -64,7 +67,12 @@ class AudioProcessor:
|
|||||||
self.asr = models.asr
|
self.asr = models.asr
|
||||||
self.tokenizer = models.tokenizer
|
self.tokenizer = models.tokenizer
|
||||||
self.diarization = models.diarization
|
self.diarization = models.diarization
|
||||||
|
self.vac_model = models.vac_model
|
||||||
|
if self.args.vac:
|
||||||
|
self.vac = FixedVADIterator(models.vac_model)
|
||||||
|
else:
|
||||||
|
self.vac = None
|
||||||
|
|
||||||
self.ffmpeg_manager = FFmpegManager(
|
self.ffmpeg_manager = FFmpegManager(
|
||||||
sample_rate=self.sample_rate,
|
sample_rate=self.sample_rate,
|
||||||
channels=self.channels
|
channels=self.channels
|
||||||
@@ -96,13 +104,23 @@ class AudioProcessor:
|
|||||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
|
||||||
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
|
async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
|
||||||
"""Thread-safe update of transcription with new data."""
|
"""Thread-safe update of transcription with new data."""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.tokens.extend(new_tokens)
|
self.tokens.extend(new_tokens)
|
||||||
|
|
||||||
|
# self.tokens, has_been_trimmed = trim_tail_repetition(
|
||||||
|
# self.tokens,
|
||||||
|
# key=lambda t: t.text.strip().lower(),
|
||||||
|
# min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too
|
||||||
|
# max_tail=200,
|
||||||
|
# prefer="longest", # prefer removing the longest repeated phrase
|
||||||
|
# keep=1
|
||||||
|
# )
|
||||||
|
# if has_been_trimmed:
|
||||||
|
# print('HAS BEEN TRIMMED !')
|
||||||
self.buffer_transcription = buffer
|
self.buffer_transcription = buffer
|
||||||
self.end_buffer = end_buffer
|
self.end_buffer = end_buffer
|
||||||
self.full_transcription = full_transcription
|
|
||||||
self.sep = sep
|
self.sep = sep
|
||||||
|
|
||||||
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
||||||
@@ -129,12 +147,12 @@ class AudioProcessor:
|
|||||||
# Calculate remaining times
|
# Calculate remaining times
|
||||||
remaining_transcription = 0
|
remaining_transcription = 0
|
||||||
if self.end_buffer > 0:
|
if self.end_buffer > 0:
|
||||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
|
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
||||||
|
|
||||||
remaining_diarization = 0
|
remaining_diarization = 0
|
||||||
if self.tokens:
|
if self.tokens:
|
||||||
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
||||||
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
|
remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"tokens": self.tokens.copy(),
|
"tokens": self.tokens.copy(),
|
||||||
@@ -153,7 +171,6 @@ class AudioProcessor:
|
|||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.buffer_transcription = self.buffer_diarization = ""
|
self.buffer_transcription = self.buffer_diarization = ""
|
||||||
self.end_buffer = self.end_attributed_speaker = 0
|
self.end_buffer = self.end_attributed_speaker = 0
|
||||||
self.full_transcription = self.last_response_content = ""
|
|
||||||
self.beg_loop = time()
|
self.beg_loop = time()
|
||||||
|
|
||||||
async def ffmpeg_stdout_reader(self):
|
async def ffmpeg_stdout_reader(self):
|
||||||
@@ -192,12 +209,6 @@ class AudioProcessor:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
self.pcm_buffer.extend(chunk)
|
self.pcm_buffer.extend(chunk)
|
||||||
|
|
||||||
# Send to diarization if enabled
|
|
||||||
if self.args.diarization and self.diarization_queue:
|
|
||||||
await self.diarization_queue.put(
|
|
||||||
self.convert_pcm_to_float(self.pcm_buffer).copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process when enough data
|
# Process when enough data
|
||||||
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
||||||
@@ -211,14 +222,44 @@ class AudioProcessor:
|
|||||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
||||||
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
||||||
|
|
||||||
# Send to transcription if enabled
|
res = None
|
||||||
if self.args.transcription and self.transcription_queue:
|
end_of_audio = False
|
||||||
await self.transcription_queue.put(pcm_array.copy())
|
silence_buffer = None
|
||||||
|
|
||||||
|
if self.args.vac:
|
||||||
|
res = self.vac(pcm_array)
|
||||||
|
|
||||||
|
if res is not None:
|
||||||
|
if res.get('end', 0) > res.get('start', 0):
|
||||||
|
end_of_audio = True
|
||||||
|
elif self.silence: #end of silence
|
||||||
|
self.silence = False
|
||||||
|
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||||
|
|
||||||
|
if silence_buffer:
|
||||||
|
if self.args.transcription and self.transcription_queue:
|
||||||
|
await self.transcription_queue.put(silence_buffer)
|
||||||
|
if self.args.diarization and self.diarization_queue:
|
||||||
|
await self.diarization_queue.put(silence_buffer)
|
||||||
|
|
||||||
|
if not self.silence:
|
||||||
|
if self.args.transcription and self.transcription_queue:
|
||||||
|
await self.transcription_queue.put(pcm_array.copy())
|
||||||
|
|
||||||
|
if self.args.diarization and self.diarization_queue:
|
||||||
|
await self.diarization_queue.put(pcm_array.copy())
|
||||||
|
|
||||||
|
self.silence_duration = 0.0
|
||||||
|
if end_of_audio:
|
||||||
|
self.silence = True
|
||||||
|
self.start_silence = time()
|
||||||
|
|
||||||
# Sleep if no processing is happening
|
# Sleep if no processing is happening
|
||||||
if not self.args.transcription and not self.args.diarization:
|
if not self.args.transcription and not self.args.diarization:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
@@ -240,36 +281,48 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def transcription_processor(self):
|
async def transcription_processor(self):
|
||||||
"""Process audio chunks for transcription."""
|
"""Process audio chunks for transcription."""
|
||||||
self.full_transcription = ""
|
|
||||||
self.sep = self.online.asr.sep
|
self.sep = self.online.asr.sep
|
||||||
cumulative_pcm_duration_stream_time = 0.0
|
cumulative_pcm_duration_stream_time = 0.0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
pcm_array = await self.transcription_queue.get()
|
item = await self.transcription_queue.get()
|
||||||
if pcm_array is SENTINEL:
|
if item is SENTINEL:
|
||||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self.online: # Should not happen if queue is used
|
if not self.online:
|
||||||
logger.warning("Transcription processor: self.online not initialized.")
|
logger.warning("Transcription processor: self.online not initialized.")
|
||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
||||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||||
|
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||||
logger.info(
|
if type(item) is Silence:
|
||||||
f"ASR processing: internal_buffer={asr_internal_buffer_duration_s:.2f}s, "
|
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||||
f"lag={transcription_lag_s:.2f}s."
|
if self.tokens:
|
||||||
)
|
asr_processing_logs += " | last_end = {self.tokens[-1].end} |"
|
||||||
|
logger.info(asr_processing_logs)
|
||||||
|
|
||||||
# Process transcription
|
if type(item) is Silence:
|
||||||
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0
|
cumulative_pcm_duration_stream_time += item.duration
|
||||||
|
self.online.insert_silence(item.duration, self.tokens[-1].end)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(item, np.ndarray):
|
||||||
|
pcm_array = item
|
||||||
|
else:
|
||||||
|
raise Exception('item should be pcm_array')
|
||||||
|
|
||||||
|
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||||
new_tokens, current_audio_processed_upto = self.online.process_iter()
|
new_tokens, current_audio_processed_upto = self.online.process_iter()
|
||||||
|
|
||||||
@@ -279,8 +332,6 @@ class AudioProcessor:
|
|||||||
|
|
||||||
if new_tokens:
|
if new_tokens:
|
||||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||||
self.full_transcription += validated_text
|
|
||||||
|
|
||||||
if buffer_text.startswith(validated_text):
|
if buffer_text.startswith(validated_text):
|
||||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
buffer_text = buffer_text[len(validated_text):].lstrip()
|
||||||
|
|
||||||
@@ -297,7 +348,7 @@ class AudioProcessor:
|
|||||||
new_end_buffer = max(candidate_end_times)
|
new_end_buffer = max(candidate_end_times)
|
||||||
|
|
||||||
await self.update_transcription(
|
await self.update_transcription(
|
||||||
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep
|
new_tokens, buffer_text, new_end_buffer, self.sep
|
||||||
)
|
)
|
||||||
self.transcription_queue.task_done()
|
self.transcription_queue.task_done()
|
||||||
|
|
||||||
@@ -312,25 +363,35 @@ class AudioProcessor:
|
|||||||
async def diarization_processor(self, diarization_obj):
|
async def diarization_processor(self, diarization_obj):
|
||||||
"""Process audio chunks for speaker diarization."""
|
"""Process audio chunks for speaker diarization."""
|
||||||
buffer_diarization = ""
|
buffer_diarization = ""
|
||||||
|
cumulative_pcm_duration_stream_time = 0.0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
pcm_array = await self.diarization_queue.get()
|
item = await self.diarization_queue.get()
|
||||||
if pcm_array is SENTINEL:
|
if item is SENTINEL:
|
||||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||||
self.diarization_queue.task_done()
|
self.diarization_queue.task_done()
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if type(item) is Silence:
|
||||||
|
cumulative_pcm_duration_stream_time += item.duration
|
||||||
|
diarization_obj.insert_silence(item.duration)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(item, np.ndarray):
|
||||||
|
pcm_array = item
|
||||||
|
else:
|
||||||
|
raise Exception('item should be pcm_array')
|
||||||
|
|
||||||
# Process diarization
|
# Process diarization
|
||||||
await diarization_obj.diarize(pcm_array)
|
await diarization_obj.diarize(pcm_array)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||||
self.end_attributed_speaker,
|
|
||||||
self.tokens,
|
self.tokens,
|
||||||
use_punctuation_split=self.args.punctuation_split
|
use_punctuation_split=self.args.punctuation_split
|
||||||
)
|
)
|
||||||
self.end_attributed_speaker = new_end
|
if len(self.tokens) > 0:
|
||||||
|
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||||
if buffer_diarization:
|
if buffer_diarization:
|
||||||
self.buffer_diarization = buffer_diarization
|
self.buffer_diarization = buffer_diarization
|
||||||
|
|
||||||
@@ -346,6 +407,8 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def results_formatter(self):
|
async def results_formatter(self):
|
||||||
"""Format processing results for output."""
|
"""Format processing results for output."""
|
||||||
|
last_sent_trans = None
|
||||||
|
last_sent_diar = None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||||
@@ -383,13 +446,16 @@ class AudioProcessor:
|
|||||||
lines = []
|
lines = []
|
||||||
last_end_diarized = 0
|
last_end_diarized = 0
|
||||||
undiarized_text = []
|
undiarized_text = []
|
||||||
|
current_time = time() - self.beg_loop if self.beg_loop else None
|
||||||
# Process each token
|
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, self.silence)
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
speaker = token.speaker
|
speaker = token.speaker
|
||||||
|
|
||||||
|
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||||
|
speaker = 1
|
||||||
|
|
||||||
# Handle diarization
|
# Handle diarization
|
||||||
if self.args.diarization:
|
if self.args.diarization and not tokens[-1].speaker == -2:
|
||||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||||
undiarized_text.append(token.text)
|
undiarized_text.append(token.text)
|
||||||
continue
|
continue
|
||||||
@@ -398,21 +464,23 @@ class AudioProcessor:
|
|||||||
if speaker not in [-1, 0]:
|
if speaker not in [-1, 0]:
|
||||||
last_end_diarized = max(token.end, last_end_diarized)
|
last_end_diarized = max(token.end, last_end_diarized)
|
||||||
|
|
||||||
# Group by speaker
|
debug_info = ""
|
||||||
|
if self.debug:
|
||||||
|
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
|
||||||
if speaker != previous_speaker or not lines:
|
if speaker != previous_speaker or not lines:
|
||||||
lines.append({
|
lines.append({
|
||||||
"speaker": speaker,
|
"speaker": speaker,
|
||||||
"text": token.text,
|
"text": token.text + debug_info,
|
||||||
"beg": format_time(token.start),
|
"beg": format_time(token.start),
|
||||||
"end": format_time(token.end),
|
"end": format_time(token.end),
|
||||||
"diff": round(token.end - last_end_diarized, 2)
|
"diff": round(token.end - last_end_diarized, 2)
|
||||||
})
|
})
|
||||||
previous_speaker = speaker
|
previous_speaker = speaker
|
||||||
elif token.text: # Only append if text isn't empty
|
elif token.text: # Only append if text isn't empty
|
||||||
lines[-1]["text"] += sep + token.text
|
lines[-1]["text"] += sep + token.text + debug_info
|
||||||
lines[-1]["end"] = format_time(token.end)
|
lines[-1]["end"] = format_time(token.end)
|
||||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||||
|
|
||||||
# Handle undiarized text
|
# Handle undiarized text
|
||||||
if undiarized_text:
|
if undiarized_text:
|
||||||
combined = sep.join(undiarized_text)
|
combined = sep.join(undiarized_text)
|
||||||
@@ -449,10 +517,19 @@ class AudioProcessor:
|
|||||||
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
|
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
|
||||||
f" | {buffer_transcription} | {buffer_diarization}"
|
f" | {buffer_transcription} | {buffer_diarization}"
|
||||||
|
|
||||||
if current_response_signature != self.last_response_content and \
|
trans = state["remaining_time_transcription"]
|
||||||
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
diar = state["remaining_time_diarization"]
|
||||||
|
should_push = (
|
||||||
|
current_response_signature != self.last_response_content
|
||||||
|
or last_sent_trans is None
|
||||||
|
or round(trans, 1) != round(last_sent_trans, 1)
|
||||||
|
or round(diar, 1) != round(last_sent_diar, 1)
|
||||||
|
)
|
||||||
|
if should_push and (final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
|
||||||
yield response
|
yield response
|
||||||
self.last_response_content = current_response_signature
|
self.last_response_content = current_response_signature
|
||||||
|
last_sent_trans = trans
|
||||||
|
last_sent_diar = diar
|
||||||
|
|
||||||
# Check for termination condition
|
# Check for termination condition
|
||||||
if self.is_stopping:
|
if self.is_stopping:
|
||||||
@@ -564,6 +641,10 @@ class AudioProcessor:
|
|||||||
|
|
||||||
async def process_audio(self, message):
|
async def process_audio(self, message):
|
||||||
"""Process incoming audio data."""
|
"""Process incoming audio data."""
|
||||||
|
|
||||||
|
if not self.beg_loop:
|
||||||
|
self.beg_loop = time()
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
logger.info("Empty audio message received, initiating stop sequence.")
|
logger.info("Empty audio message received, initiating stop sequence.")
|
||||||
self.is_stopping = True
|
self.is_stopping = True
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from starlette.staticfiles import StaticFiles
|
||||||
|
import pathlib
|
||||||
|
import whisperlivekit.web as webpkg
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logging.getLogger().setLevel(logging.WARNING)
|
logging.getLogger().setLevel(logging.WARNING)
|
||||||
@@ -30,6 +33,8 @@ app.add_middleware(
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
web_dir = pathlib.Path(webpkg.__file__).parent
|
||||||
|
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def get():
|
async def get():
|
||||||
@@ -47,7 +52,7 @@ async def handle_websocket_results(websocket, results_generator):
|
|||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
|
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error in WebSocket results handler: {e}")
|
logger.error(f"Error in WebSocket results handler: {e}")
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/asr")
|
@app.websocket("/asr")
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
try:
|
try:
|
||||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
|
||||||
|
from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
from .whisper_streaming_custom.whisper_online import backend_factory
|
||||||
|
from .whisper_streaming_custom.online_asr import OnlineASRProcessor
|
||||||
|
from whisperlivekit.warmup import warmup_asr, warmup_online
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
import sys
|
||||||
|
|
||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -22,7 +25,6 @@ class TranscriptionEngine:
|
|||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
"port": 8000,
|
"port": 8000,
|
||||||
"warmup_file": None,
|
"warmup_file": None,
|
||||||
"confidence_validation": False,
|
|
||||||
"diarization": False,
|
"diarization": False,
|
||||||
"punctuation_split": False,
|
"punctuation_split": False,
|
||||||
"min_chunk_size": 0.5,
|
"min_chunk_size": 0.5,
|
||||||
@@ -32,22 +34,22 @@ class TranscriptionEngine:
|
|||||||
"lan": "auto",
|
"lan": "auto",
|
||||||
"task": "transcribe",
|
"task": "transcribe",
|
||||||
"backend": "faster-whisper",
|
"backend": "faster-whisper",
|
||||||
"vac": False,
|
"vac": True,
|
||||||
"vac_chunk_size": 0.04,
|
"vac_chunk_size": 0.04,
|
||||||
"buffer_trimming": "segment",
|
|
||||||
"buffer_trimming_sec": 15,
|
|
||||||
"log_level": "DEBUG",
|
"log_level": "DEBUG",
|
||||||
"ssl_certfile": None,
|
"ssl_certfile": None,
|
||||||
"ssl_keyfile": None,
|
"ssl_keyfile": None,
|
||||||
"transcription": True,
|
"transcription": True,
|
||||||
"vad": True,
|
"vad": True,
|
||||||
"segmentation_model": "pyannote/segmentation-3.0",
|
# whisperstreaming params:
|
||||||
"embedding_model": "pyannote/embedding",
|
"buffer_trimming": "segment",
|
||||||
|
"confidence_validation": False,
|
||||||
|
"buffer_trimming_sec": 15,
|
||||||
# simulstreaming params:
|
# simulstreaming params:
|
||||||
"frame_threshold": 25,
|
"frame_threshold": 25,
|
||||||
"beams": 1,
|
"beams": 1,
|
||||||
"decoder_type": None,
|
"decoder_type": None,
|
||||||
"audio_max_len": 30.0,
|
"audio_max_len": 20.0,
|
||||||
"audio_min_len": 0.0,
|
"audio_min_len": 0.0,
|
||||||
"cif_ckpt_path": None,
|
"cif_ckpt_path": None,
|
||||||
"never_fire": False,
|
"never_fire": False,
|
||||||
@@ -55,6 +57,10 @@ class TranscriptionEngine:
|
|||||||
"static_init_prompt": None,
|
"static_init_prompt": None,
|
||||||
"max_context_tokens": None,
|
"max_context_tokens": None,
|
||||||
"model_path": './base.pt',
|
"model_path": './base.pt',
|
||||||
|
"diarization_backend": "diart",
|
||||||
|
# diart params:
|
||||||
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
|
"embedding_model": "pyannote/embedding",
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict = {**defaults, **kwargs}
|
config_dict = {**defaults, **kwargs}
|
||||||
@@ -63,6 +69,8 @@ class TranscriptionEngine:
|
|||||||
config_dict['transcription'] = not kwargs['no_transcription']
|
config_dict['transcription'] = not kwargs['no_transcription']
|
||||||
if 'no_vad' in kwargs:
|
if 'no_vad' in kwargs:
|
||||||
config_dict['vad'] = not kwargs['no_vad']
|
config_dict['vad'] = not kwargs['no_vad']
|
||||||
|
if 'no_vac' in kwargs:
|
||||||
|
config_dict['vac'] = not kwargs['no_vac']
|
||||||
|
|
||||||
config_dict.pop('no_transcription', None)
|
config_dict.pop('no_transcription', None)
|
||||||
config_dict.pop('no_vad', None)
|
config_dict.pop('no_vad', None)
|
||||||
@@ -76,17 +84,72 @@ class TranscriptionEngine:
|
|||||||
self.asr = None
|
self.asr = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.diarization = None
|
self.diarization = None
|
||||||
|
self.vac_model = None
|
||||||
|
|
||||||
|
if self.args.vac:
|
||||||
|
import torch
|
||||||
|
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||||
|
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
self.asr, self.tokenizer = backend_factory(self.args)
|
if self.args.backend == "simulstreaming":
|
||||||
warmup_asr(self.asr, self.args.warmup_file)
|
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||||
|
self.tokenizer = None
|
||||||
|
simulstreaming_kwargs = {}
|
||||||
|
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
||||||
|
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
||||||
|
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']:
|
||||||
|
if hasattr(self.args, attr):
|
||||||
|
simulstreaming_kwargs[attr] = getattr(self.args, attr)
|
||||||
|
|
||||||
|
# Add segment_length from min_chunk_size
|
||||||
|
simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5)
|
||||||
|
simulstreaming_kwargs['task'] = self.args.task
|
||||||
|
|
||||||
|
size = self.args.model
|
||||||
|
self.asr = SimulStreamingASR(
|
||||||
|
modelsize=size,
|
||||||
|
lan=self.args.lan,
|
||||||
|
cache_dir=getattr(self.args, 'model_cache_dir', None),
|
||||||
|
model_dir=getattr(self.args, 'model_dir', None),
|
||||||
|
**simulstreaming_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.asr, self.tokenizer = backend_factory(self.args)
|
||||||
|
warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
|
||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
if self.args.diarization_backend == "diart":
|
||||||
self.diarization = DiartDiarization(
|
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||||
block_duration=self.args.min_chunk_size,
|
self.diarization = DiartDiarization(
|
||||||
segmentation_model_name=self.args.segmentation_model,
|
block_duration=self.args.min_chunk_size,
|
||||||
embedding_model_name=self.args.embedding_model
|
segmentation_model_name=self.args.segmentation_model,
|
||||||
)
|
embedding_model_name=self.args.embedding_model
|
||||||
|
)
|
||||||
|
elif self.args.diarization_backend == "sortformer":
|
||||||
|
raise ValueError('Sortformer backend in developement')
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
||||||
|
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||||
|
if args.backend == "simulstreaming":
|
||||||
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||||
|
online = SimulStreamingOnlineProcessor(
|
||||||
|
asr,
|
||||||
|
logfile=logfile,
|
||||||
|
)
|
||||||
|
# warmup_online(online, args.warmup_file)
|
||||||
|
else:
|
||||||
|
online = OnlineASRProcessor(
|
||||||
|
asr,
|
||||||
|
tokenizer,
|
||||||
|
logfile=logfile,
|
||||||
|
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||||
|
confidence_validation = args.confidence_validation
|
||||||
|
)
|
||||||
|
return online
|
||||||
|
|
||||||
@@ -29,6 +29,7 @@ class DiarizationObserver(Observer):
|
|||||||
self.speaker_segments = []
|
self.speaker_segments = []
|
||||||
self.processed_time = 0
|
self.processed_time = 0
|
||||||
self.segment_lock = threading.Lock()
|
self.segment_lock = threading.Lock()
|
||||||
|
self.global_time_offset = 0.0
|
||||||
|
|
||||||
def on_next(self, value: Tuple[Annotation, Any]):
|
def on_next(self, value: Tuple[Annotation, Any]):
|
||||||
annotation, audio = value
|
annotation, audio = value
|
||||||
@@ -49,8 +50,8 @@ class DiarizationObserver(Observer):
|
|||||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||||
self.speaker_segments.append(SpeakerSegment(
|
self.speaker_segments.append(SpeakerSegment(
|
||||||
speaker=speaker,
|
speaker=speaker,
|
||||||
start=start,
|
start=start + self.global_time_offset,
|
||||||
end=end
|
end=end + self.global_time_offset
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
logger.debug("\nNo speakers detected in this segment")
|
logger.debug("\nNo speakers detected in this segment")
|
||||||
@@ -165,7 +166,7 @@ class WebSocketAudioSource(AudioSource):
|
|||||||
|
|
||||||
|
|
||||||
class DiartDiarization:
|
class DiartDiarization:
|
||||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
||||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||||
|
|
||||||
@@ -199,6 +200,9 @@ class DiartDiarization:
|
|||||||
self.inference.attach_observers(self.observer)
|
self.inference.attach_observers(self.observer)
|
||||||
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration):
|
||||||
|
self.observer.global_time_offset += silence_duration
|
||||||
|
|
||||||
async def diarize(self, pcm_array: np.ndarray):
|
async def diarize(self, pcm_array: np.ndarray):
|
||||||
"""
|
"""
|
||||||
Process audio data for diarization.
|
Process audio data for diarization.
|
||||||
@@ -206,15 +210,14 @@ class DiartDiarization:
|
|||||||
"""
|
"""
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.push_audio(pcm_array)
|
self.custom_source.push_audio(pcm_array)
|
||||||
self.observer.clear_old_segments()
|
# self.observer.clear_old_segments()
|
||||||
return self.observer.get_segments()
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the audio source."""
|
"""Close the audio source."""
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.close()
|
self.custom_source.close()
|
||||||
|
|
||||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||||
"""
|
"""
|
||||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||||
Uses the segments collected by the observer.
|
Uses the segments collected by the observer.
|
||||||
@@ -231,85 +234,82 @@ class DiartDiarization:
|
|||||||
|
|
||||||
if not self.lag_diart and segments and tokens:
|
if not self.lag_diart and segments and tokens:
|
||||||
self.lag_diart = segments[0].start - tokens[0].start
|
self.lag_diart = segments[0].start - tokens[0].start
|
||||||
for token in tokens:
|
|
||||||
for segment in segments:
|
|
||||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
|
||||||
token.speaker = extract_number(segment.speaker) + 1
|
|
||||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
|
||||||
|
|
||||||
if use_punctuation_split and len(tokens) > 1:
|
if not use_punctuation_split:
|
||||||
punctuation_marks = {'.', '!', '?'}
|
for token in tokens:
|
||||||
|
for segment in segments:
|
||||||
print("Here are the tokens:",
|
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||||
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
token.speaker = extract_number(segment.speaker) + 1
|
||||||
|
else:
|
||||||
segment_map = []
|
tokens = add_speaker_to_tokens(segments, tokens)
|
||||||
for segment in segments:
|
return tokens
|
||||||
speaker_num = extract_number(segment.speaker) + 1
|
|
||||||
segment_map.append((segment.start, segment.end, speaker_num))
|
|
||||||
segment_map.sort(key=lambda x: x[0])
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
while i < len(tokens):
|
|
||||||
current_token = tokens[i]
|
|
||||||
|
|
||||||
is_sentence_end = False
|
|
||||||
if current_token.text and current_token.text.strip():
|
|
||||||
text = current_token.text.strip()
|
|
||||||
if text[-1] in punctuation_marks:
|
|
||||||
is_sentence_end = True
|
|
||||||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
|
||||||
|
|
||||||
if is_sentence_end and current_token.speaker != -1:
|
|
||||||
punctuation_time = current_token.end
|
|
||||||
current_speaker = current_token.speaker
|
|
||||||
|
|
||||||
j = i + 1
|
|
||||||
next_sentence_tokens = []
|
|
||||||
while j < len(tokens):
|
|
||||||
next_token = tokens[j]
|
|
||||||
next_sentence_tokens.append(j)
|
|
||||||
|
|
||||||
# Check if this token ends the next sentence
|
|
||||||
if next_token.text and next_token.text.strip():
|
|
||||||
if next_token.text.strip()[-1] in punctuation_marks:
|
|
||||||
break
|
|
||||||
j += 1
|
|
||||||
|
|
||||||
if next_sentence_tokens:
|
|
||||||
speaker_times = {}
|
|
||||||
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
token = tokens[idx]
|
|
||||||
# Find which segments overlap with this token
|
|
||||||
for seg_start, seg_end, seg_speaker in segment_map:
|
|
||||||
if not (seg_end <= token.start or seg_start >= token.end):
|
|
||||||
# Calculate overlap duration
|
|
||||||
overlap_start = max(seg_start, token.start)
|
|
||||||
overlap_end = min(seg_end, token.end)
|
|
||||||
overlap_duration = overlap_end - overlap_start
|
|
||||||
|
|
||||||
if seg_speaker not in speaker_times:
|
|
||||||
speaker_times[seg_speaker] = 0
|
|
||||||
speaker_times[seg_speaker] += overlap_duration
|
|
||||||
|
|
||||||
if speaker_times:
|
|
||||||
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
|
||||||
|
|
||||||
if dominant_speaker != current_speaker:
|
|
||||||
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
|
||||||
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
if tokens[idx].speaker != dominant_speaker:
|
|
||||||
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
|
||||||
tokens[idx].speaker = dominant_speaker
|
|
||||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
|
||||||
else:
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
if tokens[idx].speaker == -1:
|
|
||||||
tokens[idx].speaker = current_speaker
|
|
||||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return end_attributed_speaker
|
def concatenate_speakers(segments):
|
||||||
|
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||||
|
for segment in segments:
|
||||||
|
speaker = extract_number(segment.speaker) + 1
|
||||||
|
if segments_concatenated[-1]['speaker'] != speaker:
|
||||||
|
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||||
|
else:
|
||||||
|
segments_concatenated[-1]['end'] = segment.end
|
||||||
|
# print("Segments concatenated:")
|
||||||
|
# for entry in segments_concatenated:
|
||||||
|
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
||||||
|
return segments_concatenated
|
||||||
|
|
||||||
|
|
||||||
|
def add_speaker_to_tokens(segments, tokens):
|
||||||
|
"""
|
||||||
|
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
|
||||||
|
"""
|
||||||
|
punctuation_marks = {'.', '!', '?'}
|
||||||
|
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||||
|
segments_concatenated = concatenate_speakers(segments)
|
||||||
|
for ind, segment in enumerate(segments_concatenated):
|
||||||
|
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||||
|
if punctuation_token.start > segment['end']:
|
||||||
|
after_length = punctuation_token.start - segment['end']
|
||||||
|
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||||
|
if before_length > after_length:
|
||||||
|
segment['end'] = punctuation_token.start
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||||
|
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||||
|
else:
|
||||||
|
segment['end'] = punctuation_tokens[i - 1].end
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||||
|
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||||
|
break
|
||||||
|
|
||||||
|
last_end = 0.0
|
||||||
|
for token in tokens:
|
||||||
|
start = max(last_end + 0.01, token.start)
|
||||||
|
token.start = start
|
||||||
|
token.end = max(start, token.end)
|
||||||
|
last_end = token.end
|
||||||
|
|
||||||
|
ind_last_speaker = 0
|
||||||
|
for segment in segments_concatenated:
|
||||||
|
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||||
|
if token.end <= segment['end']:
|
||||||
|
token.speaker = segment['speaker']
|
||||||
|
ind_last_speaker = i + 1
|
||||||
|
# print(
|
||||||
|
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
|
||||||
|
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
|
||||||
|
# )
|
||||||
|
elif token.start > segment['end']:
|
||||||
|
break
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_tokens(tokens):
|
||||||
|
conversation = [{"speaker": -1, "text": ""}]
|
||||||
|
for token in tokens:
|
||||||
|
speaker = conversation[-1]['speaker']
|
||||||
|
if token.speaker != speaker:
|
||||||
|
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||||
|
else:
|
||||||
|
conversation[-1]['text'] += token.text
|
||||||
|
print("Conversation:")
|
||||||
|
for entry in conversation:
|
||||||
|
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||||
145
whisperlivekit/diarization/sortformer_backend.py
Normal file
145
whisperlivekit/diarization/sortformer_backend.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
from whisperlivekit.timed_objects import SpeakerSegment
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||||
|
except ImportError:
|
||||||
|
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||||
|
|
||||||
|
class SortformerDiarization:
|
||||||
|
def __init__(self, model_name="nvidia/diar_streaming_sortformer_4spk-v2"):
|
||||||
|
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||||
|
self.diar_model.eval()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.diar_model.to(torch.device("cuda"))
|
||||||
|
|
||||||
|
# Streaming parameters for speed
|
||||||
|
self.diar_model.sortformer_modules.chunk_len = 12
|
||||||
|
self.diar_model.sortformer_modules.chunk_right_context = 1
|
||||||
|
self.diar_model.sortformer_modules.spkcache_len = 188
|
||||||
|
self.diar_model.sortformer_modules.fifo_len = 188
|
||||||
|
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||||
|
self.diar_model.sortformer_modules.log = False
|
||||||
|
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||||
|
|
||||||
|
self.batch_size = 1
|
||||||
|
self.processed_signal_offset = torch.zeros((self.batch_size,), dtype=torch.long, device=self.diar_model.device)
|
||||||
|
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.speaker_segments = []
|
||||||
|
|
||||||
|
self.streaming_state = self.diar_model.sortformer_modules.init_streaming_state(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
async_streaming=True,
|
||||||
|
device=self.diar_model.device
|
||||||
|
)
|
||||||
|
self.total_preds = torch.zeros((self.batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=self.diar_model.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_audio_signal(self, signal):
|
||||||
|
audio_signal = torch.tensor(signal).unsqueeze(0).to(self.diar_model.device)
|
||||||
|
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(self.diar_model.device)
|
||||||
|
processed_signal, processed_signal_length = self.diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
|
||||||
|
return processed_signal, processed_signal_length
|
||||||
|
|
||||||
|
def _create_streaming_loader(self, processed_signal, processed_signal_length):
|
||||||
|
streaming_loader = self.diar_model.sortformer_modules.streaming_feat_loader(
|
||||||
|
feat_seq=processed_signal,
|
||||||
|
feat_seq_length=processed_signal_length,
|
||||||
|
feat_seq_offset=self.processed_signal_offset,
|
||||||
|
)
|
||||||
|
return streaming_loader
|
||||||
|
|
||||||
|
async def diarize(self, pcm_array: np.ndarray):
|
||||||
|
"""
|
||||||
|
Process an incoming audio chunk for diarization.
|
||||||
|
"""
|
||||||
|
self.audio_buffer = np.concatenate([self.audio_buffer, pcm_array])
|
||||||
|
|
||||||
|
# Process in fixed-size chunks (e.g., 1 second)
|
||||||
|
chunk_size = self.sample_rate # 1 second of audio
|
||||||
|
|
||||||
|
while len(self.audio_buffer) >= chunk_size:
|
||||||
|
chunk_to_process = self.audio_buffer[:chunk_size]
|
||||||
|
self.audio_buffer = self.audio_buffer[chunk_size:]
|
||||||
|
|
||||||
|
processed_signal, processed_signal_length = self._prepare_audio_signal(chunk_to_process)
|
||||||
|
|
||||||
|
current_offset_seconds = self.processed_signal_offset.item() * self.diar_model.preprocessor._cfg.window_stride
|
||||||
|
|
||||||
|
streaming_loader = self._create_streaming_loader(processed_signal, processed_signal_length)
|
||||||
|
|
||||||
|
frame_duration_s = self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride
|
||||||
|
chunk_duration_seconds = self.diar_model.sortformer_modules.chunk_len * frame_duration_s
|
||||||
|
|
||||||
|
for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
|
||||||
|
with torch.inference_mode():
|
||||||
|
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||||
|
processed_signal=chunk_feat_seq_t,
|
||||||
|
processed_signal_length=feat_lengths,
|
||||||
|
streaming_state=self.streaming_state,
|
||||||
|
total_preds=self.total_preds,
|
||||||
|
left_offset=left_offset,
|
||||||
|
right_offset=right_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_new_frames = feat_lengths[0].item()
|
||||||
|
|
||||||
|
# Get predictions for the current chunk from the end of total_preds
|
||||||
|
preds_np = self.total_preds[0, -num_new_frames:].cpu().numpy()
|
||||||
|
active_speakers = np.argmax(preds_np, axis=1)
|
||||||
|
|
||||||
|
for idx, spk in enumerate(active_speakers):
|
||||||
|
start_time = current_offset_seconds + (i * chunk_duration_seconds) + (idx * frame_duration_s)
|
||||||
|
end_time = start_time + frame_duration_s
|
||||||
|
|
||||||
|
if self.speaker_segments and self.speaker_segments[-1].speaker == spk + 1:
|
||||||
|
self.speaker_segments[-1].end = end_time
|
||||||
|
else:
|
||||||
|
self.speaker_segments.append(SpeakerSegment(
|
||||||
|
speaker=int(spk + 1),
|
||||||
|
start=start_time,
|
||||||
|
end=end_time
|
||||||
|
))
|
||||||
|
|
||||||
|
self.processed_signal_offset += processed_signal_length
|
||||||
|
|
||||||
|
|
||||||
|
def assign_speakers_to_tokens(self, tokens: list, **kwargs) -> list:
|
||||||
|
"""
|
||||||
|
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||||
|
"""
|
||||||
|
for token in tokens:
|
||||||
|
for segment in self.speaker_segments:
|
||||||
|
if not (segment.end <= token.start or segment.start >= token.end):
|
||||||
|
token.speaker = segment.speaker
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Cleanup resources.
|
||||||
|
"""
|
||||||
|
logger.info("Closing SortformerDiarization.")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import librosa
|
||||||
|
an4_audio = 'new_audio_test.mp3'
|
||||||
|
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||||
|
|
||||||
|
diarization_pipeline = SortformerDiarization()
|
||||||
|
|
||||||
|
# Simulate streaming
|
||||||
|
chunk_size = 16000 # 1 second
|
||||||
|
for i in range(0, len(signal), chunk_size):
|
||||||
|
chunk = signal[i:i+chunk_size]
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(diarization_pipeline.diarize(chunk))
|
||||||
|
|
||||||
|
for segment in diarization_pipeline.speaker_segments:
|
||||||
|
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||||
257
whisperlivekit/diarization/sortformer_backend_2.py
Normal file
257
whisperlivekit/diarization/sortformer_backend_2.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||||
|
except ImportError:
|
||||||
|
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||||
|
|
||||||
|
|
||||||
|
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
|
||||||
|
diar_model.eval()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
diar_model.to(torch.device("cuda"))
|
||||||
|
|
||||||
|
# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
|
||||||
|
# diar_model.sortformer_modules.chunk_len = 6
|
||||||
|
# diar_model.sortformer_modules.spkcache_len = 188
|
||||||
|
# diar_model.sortformer_modules.chunk_right_context = 7
|
||||||
|
# diar_model.sortformer_modules.fifo_len = 188
|
||||||
|
# diar_model.sortformer_modules.spkcache_update_period = 144
|
||||||
|
# diar_model.sortformer_modules.log = False
|
||||||
|
|
||||||
|
|
||||||
|
# here we change the settings for our goal: speed!
|
||||||
|
# we want batches of around 1 second. one frame is 0.08s, so 1s is 12.5 frames. we take 12.
|
||||||
|
diar_model.sortformer_modules.chunk_len = 12
|
||||||
|
|
||||||
|
# for more speed, we reduce the 'right context'. it's like looking less into the future.
|
||||||
|
diar_model.sortformer_modules.chunk_right_context = 1
|
||||||
|
|
||||||
|
# we keep the rest same for now
|
||||||
|
diar_model.sortformer_modules.spkcache_len = 188
|
||||||
|
diar_model.sortformer_modules.fifo_len = 188
|
||||||
|
diar_model.sortformer_modules.spkcache_update_period = 144
|
||||||
|
diar_model.sortformer_modules.log = False
|
||||||
|
diar_model.sortformer_modules._check_streaming_parameters()
|
||||||
|
|
||||||
|
batch_size = 1
|
||||||
|
processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)
|
||||||
|
|
||||||
|
# from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
|
||||||
|
# from nemo.collections.asr.modules.audio_preprocessing import get_features
|
||||||
|
from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessor
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_audio_signal(signal):
|
||||||
|
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
|
||||||
|
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
|
||||||
|
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
|
||||||
|
window_size= 0.025,
|
||||||
|
normalize="NA",
|
||||||
|
n_fft=512,
|
||||||
|
features=128).get_features(audio_signal, audio_signal_length)
|
||||||
|
return processed_signal, processed_signal_length
|
||||||
|
|
||||||
|
|
||||||
|
def streaming_feat_loader(
|
||||||
|
feat_seq, feat_seq_length, feat_seq_offset
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Load a chunk of feature sequence for streaming inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feat_seq (torch.Tensor): Tensor containing feature sequence
|
||||||
|
Shape: (batch_size, feat_dim, feat frame count)
|
||||||
|
feat_seq_length (torch.Tensor): Tensor containing feature sequence lengths
|
||||||
|
Shape: (batch_size,)
|
||||||
|
feat_seq_offset (torch.Tensor): Tensor containing feature sequence offsets
|
||||||
|
Shape: (batch_size,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
chunk_idx (int): Index of the current chunk
|
||||||
|
chunk_feat_seq (torch.Tensor): Tensor containing the chunk of feature sequence
|
||||||
|
Shape: (batch_size, diar frame count, feat_dim)
|
||||||
|
feat_lengths (torch.Tensor): Tensor containing lengths of the chunk of feature sequence
|
||||||
|
Shape: (batch_size,)
|
||||||
|
"""
|
||||||
|
feat_len = feat_seq.shape[2]
|
||||||
|
num_chunks = math.ceil(feat_len / (diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor))
|
||||||
|
if False:
|
||||||
|
logging.info(
|
||||||
|
f"feat_len={feat_len}, num_chunks={num_chunks}, "
|
||||||
|
f"feat_seq_length={feat_seq_length}, feat_seq_offset={feat_seq_offset}"
|
||||||
|
)
|
||||||
|
|
||||||
|
stt_feat, end_feat, chunk_idx = 0, 0, 0
|
||||||
|
while end_feat < feat_len:
|
||||||
|
left_offset = min(diar_model.sortformer_modules.chunk_left_context * diar_model.sortformer_modules.subsampling_factor, stt_feat)
|
||||||
|
end_feat = min(stt_feat + diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor, feat_len)
|
||||||
|
right_offset = min(diar_model.sortformer_modules.chunk_right_context * diar_model.sortformer_modules.subsampling_factor, feat_len - end_feat)
|
||||||
|
chunk_feat_seq = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
|
||||||
|
feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
|
||||||
|
0, chunk_feat_seq.shape[2]
|
||||||
|
)
|
||||||
|
feat_lengths = feat_lengths * (feat_seq_offset < end_feat)
|
||||||
|
stt_feat = end_feat
|
||||||
|
chunk_feat_seq_t = torch.transpose(chunk_feat_seq, 1, 2)
|
||||||
|
if False:
|
||||||
|
logging.info(
|
||||||
|
f"chunk_idx: {chunk_idx}, "
|
||||||
|
f"chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
|
||||||
|
f"chunk_feat_lengths: {feat_lengths}"
|
||||||
|
)
|
||||||
|
yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
|
||||||
|
chunk_idx += 1
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingSortformerState:
|
||||||
|
"""
|
||||||
|
This class creates a class instance that will be used to store the state of the
|
||||||
|
streaming Sortformer model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||||
|
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||||
|
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||||
|
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||||
|
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||||
|
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||||
|
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||||
|
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||||
|
n_sil_frames (torch.Tensor): Number of silence frames
|
||||||
|
"""
|
||||||
|
|
||||||
|
spkcache = None # Speaker cache to store embeddings from start
|
||||||
|
spkcache_lengths = None #
|
||||||
|
spkcache_preds = None # speaker cache predictions
|
||||||
|
fifo = None # to save the embedding from the latest chunks
|
||||||
|
fifo_lengths = None
|
||||||
|
fifo_preds = None
|
||||||
|
spk_perm = None
|
||||||
|
mean_sil_emb = None
|
||||||
|
n_sil_frames = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
|
||||||
|
"""
|
||||||
|
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size (int): Batch size for tensors in streaming state
|
||||||
|
async_streaming (bool): True for asynchronous update, False for synchronous update
|
||||||
|
device (torch.device): Device for tensors in streaming state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
streaming_state (SortformerStreamingState): initialized streaming state
|
||||||
|
"""
|
||||||
|
streaming_state = StreamingSortformerState()
|
||||||
|
if async_streaming:
|
||||||
|
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
|
||||||
|
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
|
||||||
|
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
|
||||||
|
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
else:
|
||||||
|
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||||
|
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||||
|
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
|
||||||
|
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
return streaming_state
|
||||||
|
|
||||||
|
def process_diarization(signal, chunks):
|
||||||
|
|
||||||
|
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
|
||||||
|
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
|
||||||
|
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
|
||||||
|
window_size= 0.025,
|
||||||
|
normalize="NA",
|
||||||
|
n_fft=512,
|
||||||
|
features=128).get_features(audio_signal, audio_signal_length)
|
||||||
|
|
||||||
|
|
||||||
|
streaming_loader = streaming_feat_loader(processed_signal, processed_signal_length, processed_signal_offset)
|
||||||
|
|
||||||
|
|
||||||
|
streaming_state = init_streaming_state(diar_model.sortformer_modules,
|
||||||
|
batch_size = batch_size,
|
||||||
|
async_streaming = True,
|
||||||
|
device = diar_model.device
|
||||||
|
)
|
||||||
|
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
|
||||||
|
|
||||||
|
|
||||||
|
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
|
||||||
|
print(f"Chunk duration: {chunk_duration_seconds} seconds")
|
||||||
|
|
||||||
|
l_speakers = [
|
||||||
|
{'start_time': 0,
|
||||||
|
'end_time': 0,
|
||||||
|
'speaker': 0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
len_prediction = None
|
||||||
|
left_offset = 0
|
||||||
|
right_offset = 8
|
||||||
|
for i, chunk_feat_seq_t, _, _, _ in streaming_loader:
|
||||||
|
with torch.inference_mode():
|
||||||
|
streaming_state, total_preds = diar_model.forward_streaming_step(
|
||||||
|
processed_signal=chunk_feat_seq_t,
|
||||||
|
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||||
|
streaming_state=streaming_state,
|
||||||
|
total_preds=total_preds,
|
||||||
|
left_offset=left_offset,
|
||||||
|
right_offset=right_offset,
|
||||||
|
)
|
||||||
|
left_offset = 8
|
||||||
|
preds_np = total_preds[0].cpu().numpy()
|
||||||
|
active_speakers = np.argmax(preds_np, axis=1)
|
||||||
|
if len_prediction is None:
|
||||||
|
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
|
||||||
|
frame_duration = chunk_duration_seconds / len_prediction
|
||||||
|
active_speakers = active_speakers[-len_prediction:]
|
||||||
|
print(chunk_feat_seq_t.shape, total_preds.shape)
|
||||||
|
for idx, spk in enumerate(active_speakers):
|
||||||
|
if spk != l_speakers[-1]['speaker']:
|
||||||
|
l_speakers.append(
|
||||||
|
{'start_time': i * chunk_duration_seconds + idx * frame_duration,
|
||||||
|
'end_time': i * chunk_duration_seconds + (idx + 1) * frame_duration,
|
||||||
|
'speaker': spk
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
|
||||||
|
|
||||||
|
print(l_speakers)
|
||||||
|
"""
|
||||||
|
Should print
|
||||||
|
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
|
||||||
|
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
|
||||||
|
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
|
||||||
|
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import librosa
|
||||||
|
an4_audio = 'new_audio_test.mp3'
|
||||||
|
signal, sr = librosa.load(an4_audio,sr=16000)
|
||||||
|
|
||||||
|
"""
|
||||||
|
ground truth:
|
||||||
|
speaker 0 : 0:00 - 0:09
|
||||||
|
speaker 1 : 0:09 - 0:19
|
||||||
|
speaker 2 : 0:19 - 0:25
|
||||||
|
speaker 0 : 0:25 - end
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Simulate streaming
|
||||||
|
chunk_size = 16000 # 1 second
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, len(signal), chunk_size):
|
||||||
|
chunk = signal[i:i+chunk_size]
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
process_diarization(signal, chunks)
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
def get_module_path():
|
|
||||||
return os.path.dirname(inspect.getfile(inspect.currentframe()))
|
|
||||||
|
|
||||||
GITHUB_API_URL = "https://api.github.com/repos/ufal/SimulStreaming/contents/simul_whisper/whisper"
|
|
||||||
RAW_BASE_URL = "https://raw.githubusercontent.com/ufal/SimulStreaming/main/simul_whisper/whisper"
|
|
||||||
TARGET_DIR = os.path.join(get_module_path(), "simul_whisper", "whisper")
|
|
||||||
|
|
||||||
def download_files_from_github(api_url, local_dir):
|
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
|
||||||
response = requests.get(api_url)
|
|
||||||
response.raise_for_status()
|
|
||||||
items = response.json()
|
|
||||||
for item in items:
|
|
||||||
if item['type'] == 'file':
|
|
||||||
download_url = item['download_url']
|
|
||||||
file_name = item['name']
|
|
||||||
file_response = requests.get(download_url)
|
|
||||||
file_response.raise_for_status()
|
|
||||||
with open(os.path.join(local_dir, file_name), 'wb') as f:
|
|
||||||
f.write(file_response.content)
|
|
||||||
elif item['type'] == 'dir':
|
|
||||||
# Recursive call for subdirectories
|
|
||||||
download_files_from_github(item['url'], os.path.join(local_dir, item['name']))
|
|
||||||
|
|
||||||
def download_simulstreaming_backend():
|
|
||||||
print(f"Downloading files into {TARGET_DIR} ...")
|
|
||||||
download_files_from_github(GITHUB_API_URL, TARGET_DIR)
|
|
||||||
print("✅ Download of SimulStreaming backend files completed successfully.")
|
|
||||||
@@ -143,7 +143,7 @@ class FFmpegManager:
|
|||||||
try:
|
try:
|
||||||
data = await asyncio.wait_for(
|
data = await asyncio.wait_for(
|
||||||
self.process.stdout.read(size),
|
self.process.stdout.read(size),
|
||||||
timeout=5.0
|
timeout=20.0
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
@@ -58,6 +58,14 @@ def parse_args():
|
|||||||
help="Hugging Face model ID for pyannote.audio embedding model.",
|
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--diarization-backend",
|
||||||
|
type=str,
|
||||||
|
default="diart",
|
||||||
|
choices=["sortformer", "diart"],
|
||||||
|
help="The diarization backend to use.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-transcription",
|
"--no-transcription",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -74,7 +82,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
default="tiny",
|
default="small",
|
||||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -107,15 +115,15 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="faster-whisper",
|
default="simulstreaming",
|
||||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||||
help="Load only this backend for Whisper processing.",
|
help="Load only this backend for Whisper processing.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vac",
|
"--no-vac",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
help="Disable VAC = voice activity controller.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||||
@@ -242,6 +250,14 @@ def parse_args():
|
|||||||
dest="model_path",
|
dest="model_path",
|
||||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--preloaded_model_count",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
dest="preloaded_model_count",
|
||||||
|
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
110
whisperlivekit/remove_silences.py
Normal file
110
whisperlivekit/remove_silences.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
import re
|
||||||
|
|
||||||
|
MIN_SILENCE_DURATION = 4 #in seconds
|
||||||
|
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
|
||||||
|
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
|
||||||
|
|
||||||
|
def blank_to_silence(tokens):
|
||||||
|
full_string = ''.join([t.text for t in tokens])
|
||||||
|
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
|
||||||
|
matches = []
|
||||||
|
for pattern in patterns:
|
||||||
|
for m in pattern.finditer(full_string):
|
||||||
|
matches.append({
|
||||||
|
'start': m.start(),
|
||||||
|
'end': m.end()
|
||||||
|
})
|
||||||
|
if matches:
|
||||||
|
# cleaned = pattern.sub(' ', full_string).strip()
|
||||||
|
# print("Cleaned:", cleaned)
|
||||||
|
cumulated_len = 0
|
||||||
|
silence_token = None
|
||||||
|
cleaned_tokens = []
|
||||||
|
for token in tokens:
|
||||||
|
if matches:
|
||||||
|
start = cumulated_len
|
||||||
|
end = cumulated_len + len(token.text)
|
||||||
|
cumulated_len = end
|
||||||
|
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||||
|
if silence_token: #previous token was already silence
|
||||||
|
silence_token.start = min(silence_token.start, token.start)
|
||||||
|
silence_token.end = max(silence_token.end, token.end)
|
||||||
|
else: #new silence
|
||||||
|
silence_token = ASRToken(
|
||||||
|
start=token.start,
|
||||||
|
end=token.end,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if silence_token: #there was silence but no more
|
||||||
|
if silence_token.end - silence_token.start >= MIN_SILENCE_DURATION:
|
||||||
|
cleaned_tokens.append(
|
||||||
|
silence_token
|
||||||
|
)
|
||||||
|
silence_token = None
|
||||||
|
matches.pop(0)
|
||||||
|
cleaned_tokens.append(token)
|
||||||
|
# print(cleaned_tokens)
|
||||||
|
return cleaned_tokens
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def no_token_to_silence(tokens):
|
||||||
|
new_tokens = []
|
||||||
|
silence_token = None
|
||||||
|
for token in tokens:
|
||||||
|
if token.speaker == -2:
|
||||||
|
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||||
|
new_tokens[-1].end = token.end
|
||||||
|
else:
|
||||||
|
new_tokens.append(token)
|
||||||
|
|
||||||
|
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||||
|
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
|
||||||
|
if new_tokens and new_tokens[-1].speaker == -2:
|
||||||
|
new_tokens[-1].end = token.start
|
||||||
|
else:
|
||||||
|
silence_token = ASRToken(
|
||||||
|
start=last_end,
|
||||||
|
end=token.start,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
new_tokens.append(silence_token)
|
||||||
|
|
||||||
|
if token.speaker != -2:
|
||||||
|
new_tokens.append(token)
|
||||||
|
return new_tokens
|
||||||
|
|
||||||
|
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||||
|
if not tokens:
|
||||||
|
return [], buffer_transcription, buffer_diarization
|
||||||
|
last_token = tokens[-1]
|
||||||
|
if tokens and (
|
||||||
|
current_time - last_token.end >= END_SILENCE_DURATION
|
||||||
|
or
|
||||||
|
(current_time - last_token.end >= 3 and vac_detected_silence)
|
||||||
|
):
|
||||||
|
if last_token.speaker == -2:
|
||||||
|
last_token.end = current_time
|
||||||
|
else:
|
||||||
|
tokens.append(
|
||||||
|
ASRToken(
|
||||||
|
start=tokens[-1].end,
|
||||||
|
end=current_time,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
)
|
||||||
|
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
|
||||||
|
buffer_diarization = ""
|
||||||
|
return tokens, buffer_transcription, buffer_diarization
|
||||||
|
|
||||||
|
|
||||||
|
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
|
||||||
|
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||||
|
tokens = no_token_to_silence(tokens)
|
||||||
|
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
|
||||||
|
return tokens, buffer_transcription, buffer_diarization
|
||||||
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SimulStreamingASR",
|
||||||
|
"SimulStreamingOnlineProcessor",
|
||||||
|
]
|
||||||
|
|||||||
315
whisperlivekit/simul_whisper/backend.py
Normal file
315
whisperlivekit/simul_whisper/backend.py
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
import logging
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
|
from whisperlivekit.warmup import load_file
|
||||||
|
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
|
||||||
|
from .whisper import load_model, tokenizer
|
||||||
|
from .whisper.audio import TOKENS_PER_SECOND
|
||||||
|
|
||||||
|
import os
|
||||||
|
import gc
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
|
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||||
|
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"""SimulStreaming dependencies are not available.
|
||||||
|
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
|
||||||
|
|
||||||
|
# TOO_MANY_REPETITIONS = 3
|
||||||
|
|
||||||
|
class SimulStreamingOnlineProcessor:
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
asr,
|
||||||
|
logfile=sys.stderr,
|
||||||
|
warmup_file=None
|
||||||
|
):
|
||||||
|
self.asr = asr
|
||||||
|
self.logfile = logfile
|
||||||
|
self.end = 0.0
|
||||||
|
self.global_time_offset = 0.0
|
||||||
|
|
||||||
|
self.committed: List[ASRToken] = []
|
||||||
|
self.last_result_tokens: List[ASRToken] = []
|
||||||
|
self.load_new_backend()
|
||||||
|
if asr.tokenizer:
|
||||||
|
self.model.tokenizer = asr.tokenizer
|
||||||
|
|
||||||
|
def load_new_backend(self):
|
||||||
|
model = self.asr.get_new_model_instance()
|
||||||
|
self.model = PaddedAlignAttWhisper(
|
||||||
|
cfg=self.asr.cfg,
|
||||||
|
loaded_model=model)
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration, offset):
|
||||||
|
"""
|
||||||
|
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||||
|
"""
|
||||||
|
if silence_duration < 5:
|
||||||
|
gap_silence = torch.zeros(int(16000*silence_duration))
|
||||||
|
self.model.insert_audio(gap_silence)
|
||||||
|
# self.global_time_offset += silence_duration
|
||||||
|
else:
|
||||||
|
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
|
||||||
|
self.model.refresh_segment(complete=True)
|
||||||
|
self.global_time_offset += silence_duration + offset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||||
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
|
|
||||||
|
# Convert numpy array to torch tensor
|
||||||
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
|
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||||
|
self.model.insert_audio(audio_tensor)
|
||||||
|
|
||||||
|
def get_buffer(self):
|
||||||
|
return Transcript(
|
||||||
|
start=None,
|
||||||
|
end=None,
|
||||||
|
text='',
|
||||||
|
probability=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def timestamped_text(self, tokens, generation):
|
||||||
|
"""
|
||||||
|
generate timestamped text from tokens and generation data.
|
||||||
|
|
||||||
|
args:
|
||||||
|
tokens: List of tokens to process
|
||||||
|
generation: Dictionary containing generation progress and optionally results
|
||||||
|
|
||||||
|
returns:
|
||||||
|
List of tuples containing (start_time, end_time, word) for each word
|
||||||
|
"""
|
||||||
|
FRAME_DURATION = 0.02
|
||||||
|
if "result" in generation:
|
||||||
|
split_words = generation["result"]["split_words"]
|
||||||
|
split_tokens = generation["result"]["split_tokens"]
|
||||||
|
else:
|
||||||
|
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
|
||||||
|
progress = generation["progress"]
|
||||||
|
frames = [p["most_attended_frames"][0] for p in progress]
|
||||||
|
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
|
||||||
|
tokens_queue = tokens.copy()
|
||||||
|
timestamped_words = []
|
||||||
|
|
||||||
|
for word, word_tokens in zip(split_words, split_tokens):
|
||||||
|
# start_frame = None
|
||||||
|
# end_frame = None
|
||||||
|
for expected_token in word_tokens:
|
||||||
|
if not tokens_queue or not frames:
|
||||||
|
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
|
||||||
|
|
||||||
|
actual_token = tokens_queue.pop(0)
|
||||||
|
current_frame = frames.pop(0)
|
||||||
|
current_timestamp = absolute_timestamps.pop(0)
|
||||||
|
if actual_token != expected_token:
|
||||||
|
raise ValueError(
|
||||||
|
f"Token mismatch: expected '{expected_token}', "
|
||||||
|
f"got '{actual_token}' at frame {current_frame}"
|
||||||
|
)
|
||||||
|
# if start_frame is None:
|
||||||
|
# start_frame = current_frame
|
||||||
|
# end_frame = current_frame
|
||||||
|
# start_time = start_frame * FRAME_DURATION
|
||||||
|
# end_time = end_frame * FRAME_DURATION
|
||||||
|
start_time = current_timestamp
|
||||||
|
end_time = current_timestamp + 0.1
|
||||||
|
timestamp_entry = (start_time, end_time, word)
|
||||||
|
timestamped_words.append(timestamp_entry)
|
||||||
|
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
|
||||||
|
return timestamped_words
|
||||||
|
|
||||||
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""
|
||||||
|
Process accumulated audio chunks using SimulStreaming.
|
||||||
|
|
||||||
|
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tokens, generation_progress = self.model.infer(is_last=is_last)
|
||||||
|
ts_words = self.timestamped_text(tokens, generation_progress)
|
||||||
|
|
||||||
|
new_tokens = []
|
||||||
|
for ts_word in ts_words:
|
||||||
|
|
||||||
|
start, end, word = ts_word
|
||||||
|
token = ASRToken(
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
text=word,
|
||||||
|
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||||
|
).with_offset(
|
||||||
|
self.global_time_offset
|
||||||
|
)
|
||||||
|
new_tokens.append(token)
|
||||||
|
|
||||||
|
# identical_tokens = 0
|
||||||
|
# n_new_tokens = len(new_tokens)
|
||||||
|
# if n_new_tokens:
|
||||||
|
|
||||||
|
self.committed.extend(new_tokens)
|
||||||
|
|
||||||
|
# if token in self.committed:
|
||||||
|
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
|
||||||
|
# if pos:
|
||||||
|
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
|
||||||
|
# commited_segment = self.committed[i:i+n_new_tokens]
|
||||||
|
# if commited_segment == new_tokens:
|
||||||
|
# identical_segments +=1
|
||||||
|
# if identical_tokens >= TOO_MANY_REPETITIONS:
|
||||||
|
# logger.warning('Too many repetition, model is stuck. Load a new one')
|
||||||
|
# self.committed = self.committed[:i]
|
||||||
|
# self.load_new_backend()
|
||||||
|
# return [], self.end
|
||||||
|
|
||||||
|
# pos = self.committed.rindex(token)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return new_tokens, self.end
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"SimulStreaming processing error: {e}")
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def warmup(self, audio, init_prompt=""):
|
||||||
|
"""Warmup the SimulStreaming model."""
|
||||||
|
try:
|
||||||
|
self.model.insert_audio(audio)
|
||||||
|
self.model.infer(True)
|
||||||
|
self.model.refresh_segment(complete=True)
|
||||||
|
logger.info("SimulStreaming model warmed up successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
# free the model and add a new model to stack.
|
||||||
|
# del self.model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
# self.asr.new_model_to_stack()
|
||||||
|
self.model.remove_hooks()
|
||||||
|
|
||||||
|
class SimulStreamingASR():
|
||||||
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
|
sep = ""
|
||||||
|
|
||||||
|
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||||
|
logger.warning(SIMULSTREAMING_LICENSE)
|
||||||
|
self.logfile = logfile
|
||||||
|
self.transcribe_kargs = {}
|
||||||
|
self.original_language = None if lan == "auto" else lan
|
||||||
|
|
||||||
|
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
||||||
|
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
||||||
|
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
|
||||||
|
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
||||||
|
self.segment_length = kwargs.get('segment_length', 0.5)
|
||||||
|
self.beams = kwargs.get('beams', 1)
|
||||||
|
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
||||||
|
self.task = kwargs.get('task', 'transcribe')
|
||||||
|
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
||||||
|
self.never_fire = kwargs.get('never_fire', False)
|
||||||
|
self.init_prompt = kwargs.get('init_prompt', None)
|
||||||
|
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
||||||
|
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
||||||
|
self.warmup_file = kwargs.get('warmup_file', None)
|
||||||
|
self.preload_model_count = kwargs.get('preload_model_count', 1)
|
||||||
|
|
||||||
|
if model_dir is not None:
|
||||||
|
self.model_path = model_dir
|
||||||
|
elif modelsize is not None:
|
||||||
|
model_mapping = {
|
||||||
|
'tiny': './tiny.pt',
|
||||||
|
'base': './base.pt',
|
||||||
|
'small': './small.pt',
|
||||||
|
'medium': './medium.pt',
|
||||||
|
'medium.en': './medium.en.pt',
|
||||||
|
'large-v1': './large-v1.pt',
|
||||||
|
'base.en': './base.en.pt',
|
||||||
|
'small.en': './small.en.pt',
|
||||||
|
'tiny.en': './tiny.en.pt',
|
||||||
|
'large-v2': './large-v2.pt',
|
||||||
|
'large-v3': './large-v3.pt',
|
||||||
|
'large': './large-v3.pt'
|
||||||
|
}
|
||||||
|
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
||||||
|
|
||||||
|
# Set up tokenizer for translation if needed
|
||||||
|
if self.task == "translate":
|
||||||
|
self.tokenizer = self.set_translate_task()
|
||||||
|
else:
|
||||||
|
self.tokenizer = None
|
||||||
|
self.cfg = AlignAttConfig(
|
||||||
|
model_path=self.model_path,
|
||||||
|
segment_length=self.segment_length,
|
||||||
|
frame_threshold=self.frame_threshold,
|
||||||
|
language=self.original_language,
|
||||||
|
audio_max_len=self.audio_max_len,
|
||||||
|
audio_min_len=self.audio_min_len,
|
||||||
|
cif_ckpt_path=self.cif_ckpt_path,
|
||||||
|
decoder_type="beam",
|
||||||
|
beam_size=self.beams,
|
||||||
|
task=self.task,
|
||||||
|
never_fire=self.never_fire,
|
||||||
|
init_prompt=self.init_prompt,
|
||||||
|
max_context_tokens=self.max_context_tokens,
|
||||||
|
static_init_prompt=self.static_init_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||||
|
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||||
|
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
whisper_model = load_model(name=self.model_name, download_root=self.model_path)
|
||||||
|
warmup_audio = load_file(self.warmup_file)
|
||||||
|
whisper_model.transcribe(warmup_audio, language=self.original_language)
|
||||||
|
return whisper_model
|
||||||
|
|
||||||
|
def get_new_model_instance(self):
|
||||||
|
"""
|
||||||
|
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
|
||||||
|
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
|
||||||
|
"""
|
||||||
|
if len(self.models) == 0:
|
||||||
|
self.models.append(self.load_model())
|
||||||
|
new_model = self.models.pop()
|
||||||
|
return new_model
|
||||||
|
# self.models[0]
|
||||||
|
|
||||||
|
def new_model_to_stack(self):
|
||||||
|
self.models.append(self.load_model())
|
||||||
|
|
||||||
|
|
||||||
|
def set_translate_task(self):
|
||||||
|
"""Set up translation task."""
|
||||||
|
return tokenizer.get_tokenizer(
|
||||||
|
multilingual=True,
|
||||||
|
language=self.model.cfg.language,
|
||||||
|
num_languages=self.model.model.num_languages,
|
||||||
|
task="translate"
|
||||||
|
)
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
"""
|
||||||
|
Warmup is done directly in load_model
|
||||||
|
"""
|
||||||
|
pass
|
||||||
@@ -8,7 +8,7 @@ class SimulWhisperConfig:
|
|||||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
||||||
model_path: str
|
model_path: str
|
||||||
language: str = field(default="zh")
|
language: str = field(default="zh")
|
||||||
nonspeech_prob: float = 1.0
|
nonspeech_prob: float = 0.5
|
||||||
audio_min_len: float = 1.0
|
audio_min_len: float = 1.0
|
||||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||||
beam_size: int = 5
|
beam_size: int = 5
|
||||||
@@ -24,6 +24,6 @@ class AlignAttConfig(SimulWhisperConfig):
|
|||||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||||
frame_threshold: int = 4
|
frame_threshold: int = 4
|
||||||
rewind_threshold: int = 200
|
rewind_threshold: int = 200
|
||||||
audio_max_len: float = 30.0
|
audio_max_len: float = 20.0
|
||||||
cif_ckpt_path: str = ""
|
cif_ckpt_path: str = ""
|
||||||
never_fire: bool = False
|
never_fire: bool = False
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
|
|
||||||
|
|
||||||
SimulStreaming is dual-licensed:
|
|
||||||
|
|
||||||
🔹 Non-Commercial Use
|
|
||||||
|
|
||||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
|
|
||||||
obtain the code through the GitHub repository. This license is **free of charge**
|
|
||||||
and comes with **no obligations** for non-commercial users.
|
|
||||||
|
|
||||||
🔸 Commercial Use
|
|
||||||
|
|
||||||
Understanding who uses SimulStreaming commercially helps us improve and
|
|
||||||
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
|
|
||||||
|
|
||||||
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
|
|
||||||
are considering to provide commercial licenses either for free or for symbolic
|
|
||||||
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
|
|
||||||
|
|
||||||
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
|
|
||||||
available.
|
|
||||||
|
|
||||||
✉️ Contact
|
|
||||||
|
|
||||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
|
||||||
@@ -25,6 +25,9 @@ class BeamTokens(Tokens):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__str__()
|
return self.__str__()
|
||||||
|
|
||||||
|
def as_text(self, tokenizer):
|
||||||
|
return tokenizer.decode(self.tokens)
|
||||||
|
|
||||||
class Logits(Tokens):
|
class Logits(Tokens):
|
||||||
def __init__(self, logits):
|
def __init__(self, logits):
|
||||||
super().__init__(logits)
|
super().__init__(logits)
|
||||||
|
|||||||
5
whisperlivekit/simul_whisper/license_simulstreaming.py
Normal file
5
whisperlivekit/simul_whisper/license_simulstreaming.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
SIMULSTREAMING_LICENSE = f"""
|
||||||
|
SimulStreaming backend is dual-licensed:
|
||||||
|
• Non-Commercial Use: PolyForm Noncommercial License 1.0.0.
|
||||||
|
• Commercial Use: Check SimulStreaming README (github.com/ufal/SimulStreaming) for more details.
|
||||||
|
"""
|
||||||
@@ -10,12 +10,12 @@ from .whisper import load_model, DecodingOptions, tokenizer
|
|||||||
from .config import AlignAttConfig
|
from .config import AlignAttConfig
|
||||||
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||||
from .whisper.timing import median_filter
|
from .whisper.timing import median_filter
|
||||||
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens
|
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
|
||||||
from .beam import BeamPyTorchInference
|
from .beam import BeamPyTorchInference
|
||||||
from .eow_detection import fire_at_boundary, load_cif
|
from .eow_detection import fire_at_boundary, load_cif
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer
|
from .token_buffer import TokenBuffer
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .generation_progress import *
|
from .generation_progress import *
|
||||||
@@ -24,6 +24,7 @@ DEC_PAD = 50257
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import wave
|
||||||
|
|
||||||
# New features added to the original version of Simul-Whisper:
|
# New features added to the original version of Simul-Whisper:
|
||||||
# - large-v3 model support
|
# - large-v3 model support
|
||||||
@@ -32,28 +33,30 @@ import sys
|
|||||||
# - prompt -- static vs. non-static
|
# - prompt -- static vs. non-static
|
||||||
# - context
|
# - context
|
||||||
class PaddedAlignAttWhisper:
|
class PaddedAlignAttWhisper:
|
||||||
def __init__(self, cfg: AlignAttConfig) -> None:
|
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None:
|
||||||
|
self.log_segments = 0
|
||||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
||||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
||||||
self.model = load_model(name=model_name, download_root=model_path)
|
if loaded_model:
|
||||||
|
self.model = loaded_model
|
||||||
|
else:
|
||||||
|
self.model = load_model(name=model_name, download_root=model_path)
|
||||||
|
|
||||||
logger.info(f"Model dimensions: {self.model.dims}")
|
logger.info(f"Model dimensions: {self.model.dims}")
|
||||||
|
|
||||||
decode_options = DecodingOptions(
|
self.decode_options = DecodingOptions(
|
||||||
language = cfg.language,
|
language = cfg.language,
|
||||||
without_timestamps = True,
|
without_timestamps = True,
|
||||||
task=cfg.task
|
task=cfg.task
|
||||||
)
|
)
|
||||||
self.tokenizer = tokenizer.get_tokenizer(
|
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||||
multilingual=not model_name.endswith(".en"),
|
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||||
language=cfg.language,
|
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||||
num_languages=self.model.num_languages,
|
|
||||||
task=decode_options.task
|
|
||||||
)
|
|
||||||
self.max_text_len = self.model.dims.n_text_ctx
|
self.max_text_len = self.model.dims.n_text_ctx
|
||||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
self.l_hooks = []
|
||||||
|
|
||||||
# model to detect end-of-word boundary at the end of the segment
|
# model to detect end-of-word boundary at the end of the segment
|
||||||
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
||||||
@@ -67,7 +70,8 @@ class PaddedAlignAttWhisper:
|
|||||||
t = F.softmax(net_output[1], dim=-1)
|
t = F.softmax(net_output[1], dim=-1)
|
||||||
self.dec_attns.append(t.squeeze(0))
|
self.dec_attns.append(t.squeeze(0))
|
||||||
for b in self.model.decoder.blocks:
|
for b in self.model.decoder.blocks:
|
||||||
b.cross_attn.register_forward_hook(layer_hook)
|
hook = b.cross_attn.register_forward_hook(layer_hook)
|
||||||
|
self.l_hooks.append(hook)
|
||||||
|
|
||||||
self.kv_cache = {}
|
self.kv_cache = {}
|
||||||
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
|
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
|
||||||
@@ -80,10 +84,13 @@ class PaddedAlignAttWhisper:
|
|||||||
return self.kv_cache[module.cache_id]
|
return self.kv_cache[module.cache_id]
|
||||||
|
|
||||||
for i,b in enumerate(self.model.decoder.blocks):
|
for i,b in enumerate(self.model.decoder.blocks):
|
||||||
b.attn.key.register_forward_hook(kv_hook)
|
hooks = [
|
||||||
b.attn.value.register_forward_hook(kv_hook)
|
b.attn.key.register_forward_hook(kv_hook),
|
||||||
b.cross_attn.key.register_forward_hook(kv_hook)
|
b.attn.value.register_forward_hook(kv_hook),
|
||||||
b.cross_attn.value.register_forward_hook(kv_hook)
|
b.cross_attn.key.register_forward_hook(kv_hook),
|
||||||
|
b.cross_attn.value.register_forward_hook(kv_hook),
|
||||||
|
]
|
||||||
|
self.l_hooks.extend(hooks)
|
||||||
|
|
||||||
self.align_source = {}
|
self.align_source = {}
|
||||||
self.num_align_heads = 0
|
self.num_align_heads = 0
|
||||||
@@ -95,14 +102,6 @@ class PaddedAlignAttWhisper:
|
|||||||
self.num_align_heads += 1
|
self.num_align_heads += 1
|
||||||
|
|
||||||
|
|
||||||
# init tokens (mandatory prompt)
|
|
||||||
self.initial_tokens = torch.tensor(
|
|
||||||
self.tokenizer.sot_sequence_including_notimestamps,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.model.device).unsqueeze(0)
|
|
||||||
self.initial_token_length = self.initial_tokens.shape[1]
|
|
||||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
|
||||||
|
|
||||||
# tokens to be suppressed from decoding, to prevent hallucinations
|
# tokens to be suppressed from decoding, to prevent hallucinations
|
||||||
suppress_tokens = [
|
suppress_tokens = [
|
||||||
self.tokenizer.transcribe,
|
self.tokenizer.transcribe,
|
||||||
@@ -121,6 +120,18 @@ class PaddedAlignAttWhisper:
|
|||||||
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
||||||
# blank tokens are suppresed for new segments near the line 334
|
# blank tokens are suppresed for new segments near the line 334
|
||||||
|
|
||||||
|
# it's going to be regenerated after lang id
|
||||||
|
self.segments = []
|
||||||
|
self.init_tokens()
|
||||||
|
|
||||||
|
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
self.cumulative_time_offset = 0.0
|
||||||
|
|
||||||
|
if self.cfg.max_context_tokens is None:
|
||||||
|
self.max_context_tokens = self.max_text_len
|
||||||
|
else:
|
||||||
|
self.max_context_tokens = self.cfg.max_context_tokens
|
||||||
|
self.init_context()
|
||||||
|
|
||||||
# decoder type: greedy or beam
|
# decoder type: greedy or beam
|
||||||
if cfg.decoder_type == "greedy":
|
if cfg.decoder_type == "greedy":
|
||||||
@@ -134,17 +145,19 @@ class PaddedAlignAttWhisper:
|
|||||||
self.inference.kv_cache = self.kv_cache
|
self.inference.kv_cache = self.kv_cache
|
||||||
|
|
||||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||||
|
|
||||||
|
def remove_hooks(self):
|
||||||
|
print('remove hook')
|
||||||
|
for hook in self.l_hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
# init state
|
def create_tokenizer(self, language=None):
|
||||||
self.segments = []
|
self.tokenizer = tokenizer.get_tokenizer(
|
||||||
self.tokens = [self.initial_tokens]
|
multilingual=self.tokenizer_is_multilingual,
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
language=language,
|
||||||
|
num_languages=self.model.num_languages,
|
||||||
if self.cfg.max_context_tokens is None:
|
task=self.decode_options.task
|
||||||
self.max_context_tokens = self.max_text_len
|
)
|
||||||
else:
|
|
||||||
self.max_context_tokens = self.cfg.max_context_tokens
|
|
||||||
self.init_context()
|
|
||||||
|
|
||||||
def init_context(self):
|
def init_context(self):
|
||||||
kw = {'tokenizer': self.tokenizer,
|
kw = {'tokenizer': self.tokenizer,
|
||||||
@@ -156,6 +169,19 @@ class PaddedAlignAttWhisper:
|
|||||||
if self.cfg.init_prompt is not None:
|
if self.cfg.init_prompt is not None:
|
||||||
self.context.text += self.cfg.init_prompt
|
self.context.text += self.cfg.init_prompt
|
||||||
|
|
||||||
|
def init_tokens(self):
|
||||||
|
logger.debug(f"init tokens, {len(self.segments)}")
|
||||||
|
# init tokens (mandatory prompt)
|
||||||
|
self.initial_tokens = torch.tensor(
|
||||||
|
self.tokenizer.sot_sequence_including_notimestamps,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.model.device).unsqueeze(0)
|
||||||
|
self.initial_token_length = self.initial_tokens.shape[1]
|
||||||
|
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||||
|
# self.segments = []
|
||||||
|
logger.debug(f"init tokens after, {len(self.segments)}")
|
||||||
|
self.tokens = [self.initial_tokens]
|
||||||
|
|
||||||
def trim_context(self):
|
def trim_context(self):
|
||||||
logger.info("Trimming context")
|
logger.info("Trimming context")
|
||||||
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
||||||
@@ -191,15 +217,20 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
def refresh_segment(self, complete=False):
|
def refresh_segment(self, complete=False):
|
||||||
|
|
||||||
logger.debug("Refreshing segment")
|
logger.debug("Refreshing segment:")
|
||||||
self.tokens = [self.initial_tokens]
|
self.init_tokens()
|
||||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
self.detected_language = None
|
||||||
|
self.cumulative_time_offset = 0.0
|
||||||
self.init_context()
|
self.init_context()
|
||||||
logger.debug(f"Context: {self.context}")
|
logger.debug(f"Context: {self.context}")
|
||||||
if not complete and len(self.segments) > 2:
|
if not complete and len(self.segments) > 2:
|
||||||
|
logger.debug("keeping last two segments because they are and it is not complete.")
|
||||||
self.segments = self.segments[-2:]
|
self.segments = self.segments[-2:]
|
||||||
else:
|
else:
|
||||||
|
logger.debug("removing all segments.")
|
||||||
self.segments = []
|
self.segments = []
|
||||||
|
self.log_segments += 1
|
||||||
|
|
||||||
|
|
||||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||||
@@ -208,8 +239,6 @@ class PaddedAlignAttWhisper:
|
|||||||
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _current_tokens(self):
|
def _current_tokens(self):
|
||||||
|
|
||||||
toks = self.tokens
|
toks = self.tokens
|
||||||
@@ -256,16 +285,60 @@ class PaddedAlignAttWhisper:
|
|||||||
removed_len = 0
|
removed_len = 0
|
||||||
# len of audio is bigger than buffer_len. Going to remove the first segment
|
# len of audio is bigger than buffer_len. Going to remove the first segment
|
||||||
segments_len = self.segments_len()
|
segments_len = self.segments_len()
|
||||||
while segments_len > self.cfg.audio_max_len:
|
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||||
removed_len = self.segments[0].shape[0] / 16000
|
removed_len = self.segments[0].shape[0] / 16000
|
||||||
segments_len -= removed_len
|
segments_len -= removed_len
|
||||||
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
||||||
|
self.cumulative_time_offset += removed_len # Track cumulative time removed
|
||||||
self.segments = self.segments[1:]
|
self.segments = self.segments[1:]
|
||||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
|
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
||||||
self.context.append_token_ids(self.tokens[1][0,:])
|
if len(self.tokens) > 1:
|
||||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
self.context.append_token_ids(self.tokens[1][0,:])
|
||||||
|
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||||
return removed_len
|
return removed_len
|
||||||
|
|
||||||
|
def _clean_cache(self):
|
||||||
|
'''clean the cache that stores the attention matrices and kv_cache.
|
||||||
|
It must be called every time after generation with the model.'''
|
||||||
|
# cleaning cache
|
||||||
|
self.dec_attns = []
|
||||||
|
self.kv_cache = {}
|
||||||
|
if self.decoder_type == "beam":
|
||||||
|
self.inference.kv_cache = self.kv_cache
|
||||||
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def lang_id(self, encoder_features):
|
||||||
|
"""Language detection from encoder features.
|
||||||
|
This code is trimmed and copy-pasted from whisper.decoding.detect_language .
|
||||||
|
"""
|
||||||
|
|
||||||
|
# forward pass using a single token, startoftranscript
|
||||||
|
n_audio = encoder_features.shape[0]
|
||||||
|
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
||||||
|
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||||
|
|
||||||
|
# collect detected languages; suppress all non-language tokens
|
||||||
|
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||||
|
mask[list(self.tokenizer.all_language_tokens)] = False
|
||||||
|
logits[:, mask] = -np.inf
|
||||||
|
language_tokens = logits.argmax(dim=-1)
|
||||||
|
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||||
|
language_probs = [
|
||||||
|
{
|
||||||
|
c: language_token_probs[i, j].item()
|
||||||
|
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
|
||||||
|
}
|
||||||
|
for i in range(n_audio)
|
||||||
|
]
|
||||||
|
|
||||||
|
single = encoder_features.ndim == 2
|
||||||
|
if single:
|
||||||
|
language_tokens = language_tokens[0]
|
||||||
|
language_probs = language_probs[0]
|
||||||
|
|
||||||
|
self._clean_cache()
|
||||||
|
return language_tokens, language_probs
|
||||||
|
|
||||||
### transcription / translation
|
### transcription / translation
|
||||||
|
|
||||||
@@ -273,9 +346,12 @@ class PaddedAlignAttWhisper:
|
|||||||
def infer(self, is_last=False):
|
def infer(self, is_last=False):
|
||||||
new_segment = True
|
new_segment = True
|
||||||
if len(self.segments) == 0:
|
if len(self.segments) == 0:
|
||||||
return []
|
logger.debug("No segments, nothing to do")
|
||||||
|
return [], {}
|
||||||
if not self._apply_minseglen():
|
if not self._apply_minseglen():
|
||||||
return []
|
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||||
|
input_segments = torch.cat(self.segments, dim=0)
|
||||||
|
return [], {}
|
||||||
|
|
||||||
# input_segments is concatenation of audio, it's one array
|
# input_segments is concatenation of audio, it's one array
|
||||||
if len(self.segments) > 1:
|
if len(self.segments) > 1:
|
||||||
@@ -283,8 +359,7 @@ class PaddedAlignAttWhisper:
|
|||||||
else:
|
else:
|
||||||
input_segments = self.segments[0]
|
input_segments = self.segments[0]
|
||||||
|
|
||||||
self.trim_context()
|
|
||||||
current_tokens = self._current_tokens()
|
|
||||||
|
|
||||||
# mel + padding to 30s
|
# mel + padding to 30s
|
||||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||||
@@ -295,18 +370,38 @@ class PaddedAlignAttWhisper:
|
|||||||
# the len of actual audio
|
# the len of actual audio
|
||||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||||
|
|
||||||
|
# encode
|
||||||
encoder_feature = self.model.encoder(mel)
|
encoder_feature = self.model.encoder(mel)
|
||||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
|
|
||||||
completed = False
|
|
||||||
|
|
||||||
|
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
|
||||||
|
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||||
|
# logger.debug("mel ")
|
||||||
|
if self.cfg.language == "auto" and self.detected_language is None:
|
||||||
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
|
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
|
||||||
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
|
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
|
#self.tokenizer.language = top_lan
|
||||||
|
#self.tokenizer.__post_init__()
|
||||||
|
self.create_tokenizer(top_lan)
|
||||||
|
self.detected_language = top_lan
|
||||||
|
self.init_tokens()
|
||||||
|
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||||
|
|
||||||
|
self.trim_context()
|
||||||
|
current_tokens = self._current_tokens()
|
||||||
|
#
|
||||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||||
|
|
||||||
|
|
||||||
####################### Decoding loop
|
####################### Decoding loop
|
||||||
logger.info("Decoding loop starts\n")
|
logger.info("Decoding loop starts\n")
|
||||||
|
|
||||||
|
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
|
||||||
|
completed = False
|
||||||
|
|
||||||
attn_of_alignment_heads = None
|
attn_of_alignment_heads = None
|
||||||
miost_attended_frame = None
|
most_attended_frame = None
|
||||||
|
|
||||||
token_len_before_decoding = current_tokens.shape[1]
|
token_len_before_decoding = current_tokens.shape[1]
|
||||||
|
|
||||||
@@ -412,7 +507,13 @@ class PaddedAlignAttWhisper:
|
|||||||
# for each beam, the most attended frame is:
|
# for each beam, the most attended frame is:
|
||||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||||
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
|
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
|
||||||
|
|
||||||
|
# Calculate absolute timestamps accounting for cumulative offset
|
||||||
|
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
||||||
|
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
|
||||||
|
|
||||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||||
|
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
||||||
|
|
||||||
most_attended_frame = most_attended_frames[0].item()
|
most_attended_frame = most_attended_frames[0].item()
|
||||||
|
|
||||||
@@ -515,11 +616,6 @@ class PaddedAlignAttWhisper:
|
|||||||
|
|
||||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||||
|
|
||||||
# cleaning cache
|
self._clean_cache()
|
||||||
self.dec_attns = []
|
|
||||||
self.kv_cache = {}
|
|
||||||
if self.decoder_type == "beam":
|
|
||||||
self.inference.kv_cache = self.kv_cache
|
|
||||||
self.token_decoder.reset()
|
|
||||||
|
|
||||||
return new_hypothesis, generation
|
return new_hypothesis, generation
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ def detect_language(
|
|||||||
list of dictionaries containing the probability distribution over all languages.
|
list of dictionaries containing the probability distribution over all languages.
|
||||||
"""
|
"""
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
tokenizer = get_tokenizer(model.is_multilingual)
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual, num_languages=model.num_languages
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
tokenizer.language is None
|
tokenizer.language is None
|
||||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||||
@@ -111,9 +113,6 @@ class DecodingOptions:
|
|||||||
# implementation details
|
# implementation details
|
||||||
fp16: bool = True # use fp16 for most of the calculation
|
fp16: bool = True # use fp16 for most of the calculation
|
||||||
|
|
||||||
# streaming
|
|
||||||
add_sot: Optional[bool] = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DecodingResult:
|
class DecodingResult:
|
||||||
@@ -513,19 +512,17 @@ class DecodingTask:
|
|||||||
logit_filters: List[LogitFilter]
|
logit_filters: List[LogitFilter]
|
||||||
|
|
||||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||||
self.options: DecodingOptions = self._verify_options(options)
|
self.model = model
|
||||||
if self.options.fp16:
|
|
||||||
self.model = model.half()
|
|
||||||
else:
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
language = options.language or "en"
|
language = options.language or "en"
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
model.is_multilingual, language=language, task=options.task
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=options.task,
|
||||||
)
|
)
|
||||||
self.tokenizer: Tokenizer = tokenizer
|
self.tokenizer: Tokenizer = tokenizer
|
||||||
|
self.options: DecodingOptions = self._verify_options(options)
|
||||||
# print(self.options)
|
|
||||||
|
|
||||||
self.n_group: int = options.beam_size or options.best_of or 1
|
self.n_group: int = options.beam_size or options.best_of or 1
|
||||||
self.n_ctx: int = model.dims.n_text_ctx
|
self.n_ctx: int = model.dims.n_text_ctx
|
||||||
@@ -589,7 +586,7 @@ class DecodingTask:
|
|||||||
|
|
||||||
def _get_initial_tokens(self) -> Tuple[int]:
|
def _get_initial_tokens(self) -> Tuple[int]:
|
||||||
tokens = list(self.sot_sequence)
|
tokens = list(self.sot_sequence)
|
||||||
# print("prefix", prefix)
|
|
||||||
if prefix := self.options.prefix:
|
if prefix := self.options.prefix:
|
||||||
prefix_tokens = (
|
prefix_tokens = (
|
||||||
self.tokenizer.encode(" " + prefix.strip())
|
self.tokenizer.encode(" " + prefix.strip())
|
||||||
@@ -607,15 +604,12 @@ class DecodingTask:
|
|||||||
if isinstance(prompt, str)
|
if isinstance(prompt, str)
|
||||||
else prompt
|
else prompt
|
||||||
)
|
)
|
||||||
# if self.options.add_sot:
|
|
||||||
tokens = (
|
tokens = (
|
||||||
[self.tokenizer.sot_prev]
|
[self.tokenizer.sot_prev]
|
||||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||||
+ tokens
|
+ tokens
|
||||||
)
|
)
|
||||||
#else:
|
|
||||||
# tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
|
|
||||||
# print("return", tokens)
|
|
||||||
return tuple(tokens)
|
return tuple(tokens)
|
||||||
|
|
||||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||||
@@ -663,7 +657,7 @@ class DecodingTask:
|
|||||||
if audio_features.dtype != (
|
if audio_features.dtype != (
|
||||||
torch.float16 if self.options.fp16 else torch.float32
|
torch.float16 if self.options.fp16 else torch.float32
|
||||||
):
|
):
|
||||||
raise TypeError(
|
return TypeError(
|
||||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -689,10 +683,9 @@ class DecodingTask:
|
|||||||
no_speech_probs = [np.nan] * n_batch
|
no_speech_probs = [np.nan] * n_batch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for i in range(self.sample_len): # 最多循环448次
|
for i in range(self.sample_len):
|
||||||
# print("in decode main loop", i , tokens[0].tolist())
|
|
||||||
logits = self.inference.logits(tokens, audio_features)
|
logits = self.inference.logits(tokens, audio_features)
|
||||||
# print(logits)
|
|
||||||
if (
|
if (
|
||||||
i == 0 and self.tokenizer.no_speech is not None
|
i == 0 and self.tokenizer.no_speech is not None
|
||||||
): # save no_speech_probs
|
): # save no_speech_probs
|
||||||
@@ -724,7 +717,7 @@ class DecodingTask:
|
|||||||
|
|
||||||
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||||
# print("initial_tokens", self.initial_tokens)
|
|
||||||
# detect language if requested, overwriting the language token
|
# detect language if requested, overwriting the language token
|
||||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||||
if self.options.task == "lang_id":
|
if self.options.task == "lang_id":
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from .decoding import decode as decode_function
|
|||||||
from .decoding import detect_language as detect_language_function
|
from .decoding import detect_language as detect_language_function
|
||||||
from .transcribe import transcribe as transcribe_function
|
from .transcribe import transcribe as transcribe_function
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
@@ -37,26 +36,27 @@ class ModelDimensions:
|
|||||||
n_text_layer: int
|
n_text_layer: int
|
||||||
|
|
||||||
|
|
||||||
# class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.LayerNorm):
|
||||||
# def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
# return super().forward(x.float()).type(x.dtype)
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
# class Linear(nn.Linear):
|
|
||||||
# def forward(self, x: Tensor) -> Tensor:
|
|
||||||
# return F.linear(
|
|
||||||
# x,
|
|
||||||
# self.weight.to(x.dtype),
|
|
||||||
# None if self.bias is None else self.bias.to(x.dtype),
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
# class Conv1d(nn.Conv1d):
|
class Linear(nn.Linear):
|
||||||
# def _conv_forward(
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
# self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
return F.linear(
|
||||||
# ) -> Tensor:
|
x,
|
||||||
# return super()._conv_forward(
|
self.weight.to(x.dtype),
|
||||||
# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
None if self.bias is None else self.bias.to(x.dtype),
|
||||||
# )
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(nn.Conv1d):
|
||||||
|
def _conv_forward(
|
||||||
|
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||||
|
) -> Tensor:
|
||||||
|
return super()._conv_forward(
|
||||||
|
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sinusoids(length, channels, max_timescale=10000):
|
def sinusoids(length, channels, max_timescale=10000):
|
||||||
@@ -67,21 +67,30 @@ def sinusoids(length, channels, max_timescale=10000):
|
|||||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||||
|
|
||||||
import sys ## this is mine, for debugging
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_sdpa():
|
||||||
|
prev_state = MultiHeadAttention.use_sdpa
|
||||||
|
try:
|
||||||
|
MultiHeadAttention.use_sdpa = False
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
MultiHeadAttention.use_sdpa = prev_state
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
|
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
|
||||||
|
|
||||||
use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212
|
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
|
||||||
|
|
||||||
def __init__(self, n_state: int, n_head: int, cache_id: str):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.query = nn.Linear(n_state, n_state)
|
self.query = Linear(n_state, n_state)
|
||||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
self.key = Linear(n_state, n_state, bias=False)
|
||||||
self.key.cache_id = f"{cache_id}_key"
|
self.value = Linear(n_state, n_state)
|
||||||
self.value = nn.Linear(n_state, n_state)
|
self.out = Linear(n_state, n_state)
|
||||||
self.value.cache_id = f"{cache_id}_value"
|
|
||||||
self.out = nn.Linear(n_state, n_state)
|
|
||||||
self.cache_id = cache_id
|
self.cache_id = cache_id
|
||||||
|
self.key.cache_id = f"{cache_id}_key"
|
||||||
|
self.value.cache_id = f"{cache_id}_value"
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -90,45 +99,21 @@ class MultiHeadAttention(nn.Module):
|
|||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[dict] = None,
|
kv_cache: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
#print("MultiHeadAttention forward",file=sys.stderr)
|
|
||||||
q = self.query(x)
|
q = self.query(x)
|
||||||
# print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
|
|
||||||
# print(mask, kv_cache, xa, file=sys.stderr)
|
|
||||||
|
|
||||||
if kv_cache is None or xa is None or self.key.cache_id not in kv_cache:
|
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||||
|
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||||
|
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||||
k = self.key(x if xa is None else xa)
|
k = self.key(x if xa is None else xa)
|
||||||
v = self.value(x if xa is None else xa)
|
v = self.value(x if xa is None else xa)
|
||||||
# print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
|
|
||||||
# if kv_cache is not None:
|
|
||||||
# print(kv_cache.keys())
|
|
||||||
else:
|
else:
|
||||||
# print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache)
|
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||||
# if kv_cache is not None:
|
k = kv_cache[self.key]
|
||||||
# print(kv_cache.keys())
|
v = kv_cache[self.value]
|
||||||
k = kv_cache[self.key.cache_id]
|
|
||||||
v = kv_cache[self.value.cache_id]
|
|
||||||
# print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
|
|
||||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||||
return self.out(wv), qk
|
return self.out(wv), qk
|
||||||
|
|
||||||
# def qkv_attention(
|
|
||||||
# self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
|
||||||
# ):
|
|
||||||
# n_batch, n_ctx, n_state = q.shape
|
|
||||||
# scale = (n_state // self.n_head) ** -0.25
|
|
||||||
# q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
|
||||||
# k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
|
||||||
# v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
|
||||||
|
|
||||||
# qk = q @ k
|
|
||||||
# if mask is not None:
|
|
||||||
# qk = qk + mask[:n_ctx, :n_ctx]
|
|
||||||
# # qk = qk.float()
|
|
||||||
|
|
||||||
# w = F.softmax(qk, dim=-1) # .to(q.dtype)
|
|
||||||
# return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
|
||||||
|
|
||||||
|
|
||||||
def qkv_attention(
|
def qkv_attention(
|
||||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@@ -158,21 +143,22 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False):
|
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
||||||
self.attn_ln = nn.LayerNorm(n_state)
|
self.attn_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
self.cross_attn = (
|
||||||
|
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
||||||
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
|
)
|
||||||
|
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||||
|
|
||||||
n_mlp = n_state * 4
|
n_mlp = n_state * 4
|
||||||
self.mlp = nn.Sequential(
|
self.mlp = nn.Sequential(
|
||||||
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
|
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||||
)
|
)
|
||||||
self.mlp_ln = nn.LayerNorm(n_state)
|
self.mlp_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -181,8 +167,6 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[dict] = None,
|
kv_cache: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
# print("ResidualAttentionBlock forward",file=sys.stderr)
|
|
||||||
# print(x.shape, file=sys.stderr)
|
|
||||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||||
@@ -195,44 +179,32 @@ class AudioEncoder(nn.Module):
|
|||||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||||
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||||
|
|
||||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||||
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
|
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
|
||||||
)
|
)
|
||||||
self.ln_post = nn.LayerNorm(n_state)
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
def forward(self, x: Tensor, return_layer_results: bool=False):
|
def forward(self, x: Tensor):
|
||||||
"""
|
"""
|
||||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||||
the mel spectrogram of the audio
|
the mel spectrogram of the audio
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x = F.gelu(self.conv1(x))
|
x = F.gelu(self.conv1(x))
|
||||||
x = F.gelu(self.conv2(x))
|
x = F.gelu(self.conv2(x))
|
||||||
x = x.permute(0, 2, 1) # BDT -> BTD
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
# 两层卷积,2倍降采样
|
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||||
# 最终剩下1500帧
|
x = (x + self.positional_embedding).to(x.dtype)
|
||||||
|
|
||||||
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
|
|
||||||
|
|
||||||
layer_results = []
|
|
||||||
i = 0
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
# print(f"encoder layer {i}")
|
|
||||||
x = block(x)
|
x = block(x)
|
||||||
layer_results.append(x)
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
x = self.ln_post(x)
|
x = self.ln_post(x)
|
||||||
|
return x
|
||||||
if return_layer_results:
|
|
||||||
return x, layer_results
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TextDecoder(nn.Module):
|
class TextDecoder(nn.Module):
|
||||||
@@ -250,7 +222,7 @@ class TextDecoder(nn.Module):
|
|||||||
for i in range(n_layer)
|
for i in range(n_layer)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.ln = nn.LayerNorm(n_state)
|
self.ln = LayerNorm(n_state)
|
||||||
|
|
||||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||||
self.register_buffer("mask", mask, persistent=False)
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
@@ -262,22 +234,20 @@ class TextDecoder(nn.Module):
|
|||||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||||
the encoded audio features to be attended on
|
the encoded audio features to be attended on
|
||||||
"""
|
"""
|
||||||
|
|
||||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||||
x = (
|
x = (
|
||||||
self.token_embedding(x)
|
self.token_embedding(x)
|
||||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||||
)
|
)
|
||||||
# x = x.to(xa.dtype)
|
x = x.to(xa.dtype)
|
||||||
|
|
||||||
i = 0
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
# print(f"decoder layer {i}")
|
|
||||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||||
i += 1
|
|
||||||
|
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
|
logits = (
|
||||||
|
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||||
|
).float()
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@@ -300,7 +270,8 @@ class Whisper(nn.Module):
|
|||||||
self.dims.n_text_head,
|
self.dims.n_text_head,
|
||||||
self.dims.n_text_layer,
|
self.dims.n_text_layer,
|
||||||
)
|
)
|
||||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
# use the last half among the decoder layers for time alignment by default;
|
||||||
|
# to use a specific set of heads, see `set_alignment_heads()` below.
|
||||||
all_heads = torch.zeros(
|
all_heads = torch.zeros(
|
||||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||||
)
|
)
|
||||||
@@ -320,15 +291,11 @@ class Whisper(nn.Module):
|
|||||||
return self.encoder(mel)
|
return self.encoder(mel)
|
||||||
|
|
||||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
|
||||||
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
|
|
||||||
return self.decoder(tokens, audio_features)
|
return self.decoder(tokens, audio_features)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
# mel = mel.to(self.decoder.ln.weight.dtype)
|
|
||||||
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
|
||||||
return self.decoder(tokens, self.encoder(mel))
|
return self.decoder(tokens, self.encoder(mel))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -343,7 +310,6 @@ class Whisper(nn.Module):
|
|||||||
def num_languages(self):
|
def num_languages(self):
|
||||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||||
|
|
||||||
# 为decoder加入缓存机制,每次推理时保存上次的k和v,下次推理无需重新计算
|
|
||||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||||
"""
|
"""
|
||||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||||
|
|||||||
@@ -30,15 +30,19 @@ def remove_symbols_and_diacritics(s: str, keep=""):
|
|||||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||||
"""
|
"""
|
||||||
return "".join(
|
return "".join(
|
||||||
c
|
(
|
||||||
if c in keep
|
c
|
||||||
else ADDITIONAL_DIACRITICS[c]
|
if c in keep
|
||||||
if c in ADDITIONAL_DIACRITICS
|
else (
|
||||||
else ""
|
ADDITIONAL_DIACRITICS[c]
|
||||||
if unicodedata.category(c) == "Mn"
|
if c in ADDITIONAL_DIACRITICS
|
||||||
else " "
|
else (
|
||||||
if unicodedata.category(c)[0] in "MSP"
|
""
|
||||||
else c
|
if unicodedata.category(c) == "Mn"
|
||||||
|
else " " if unicodedata.category(c)[0] in "MSP" else c
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
for c in unicodedata.normalize("NFKD", s)
|
for c in unicodedata.normalize("NFKD", s)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -56,9 +56,8 @@ def median_filter(x: torch.Tensor, filter_width: int):
|
|||||||
|
|
||||||
@numba.jit(nopython=True)
|
@numba.jit(nopython=True)
|
||||||
def backtrace(trace: np.ndarray):
|
def backtrace(trace: np.ndarray):
|
||||||
i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N
|
i = trace.shape[0] - 1
|
||||||
j = trace.shape[1] - 1 # j=M
|
j = trace.shape[1] - 1
|
||||||
# 边界点其实无意义?
|
|
||||||
trace[0, :] = 2
|
trace[0, :] = 2
|
||||||
trace[:, 0] = 1
|
trace[:, 0] = 1
|
||||||
|
|
||||||
@@ -83,8 +82,8 @@ def backtrace(trace: np.ndarray):
|
|||||||
@numba.jit(nopython=True, parallel=True)
|
@numba.jit(nopython=True, parallel=True)
|
||||||
def dtw_cpu(x: np.ndarray):
|
def dtw_cpu(x: np.ndarray):
|
||||||
N, M = x.shape
|
N, M = x.shape
|
||||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost: x[0, 0]到x[i-1, j-1]的最小代价
|
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace:
|
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||||
|
|
||||||
cost[0, 0] = 0
|
cost[0, 0] = 0
|
||||||
for j in range(1, M + 1):
|
for j in range(1, M + 1):
|
||||||
@@ -118,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
|
|||||||
x_skew = x_skew.T.contiguous()
|
x_skew = x_skew.T.contiguous()
|
||||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||||
cost[0, 0] = 0
|
cost[0, 0] = 0
|
||||||
cost = cost.cuda()
|
cost = cost.to(x.device)
|
||||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||||
|
|
||||||
dtw_kernel[(1,)](
|
dtw_kernel[(1,)](
|
||||||
@@ -192,21 +191,19 @@ def find_alignment(
|
|||||||
for i, block in enumerate(model.decoder.blocks)
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 进行前传,获得token概率
|
from .model import disable_sdpa
|
||||||
with torch.no_grad():
|
|
||||||
|
with torch.no_grad(), disable_sdpa():
|
||||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||||
token_probs = sampled_logits.softmax(dim=-1)
|
token_probs = sampled_logits.softmax(dim=-1)
|
||||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||||
text_token_probs = text_token_probs.tolist()
|
text_token_probs = text_token_probs.tolist()
|
||||||
|
|
||||||
# 移除钩子
|
|
||||||
for hook in hooks:
|
for hook in hooks:
|
||||||
hook.remove()
|
hook.remove()
|
||||||
|
|
||||||
# heads * tokens * frames
|
# heads * tokens * frames
|
||||||
# print(model.alignment_heads)
|
|
||||||
# exit(0)
|
|
||||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||||
weights = weights[:, :, : num_frames // 2]
|
weights = weights[:, :, : num_frames // 2]
|
||||||
weights = (weights * qk_scale).softmax(dim=-1)
|
weights = (weights * qk_scale).softmax(dim=-1)
|
||||||
@@ -215,18 +212,9 @@ def find_alignment(
|
|||||||
weights = median_filter(weights, medfilt_width)
|
weights = median_filter(weights, medfilt_width)
|
||||||
|
|
||||||
matrix = weights.mean(axis=0)
|
matrix = weights.mean(axis=0)
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
text_indices, time_indices = dtw(-matrix)
|
text_indices, time_indices = dtw(-matrix)
|
||||||
|
|
||||||
print("num_frames", num_frames)
|
|
||||||
print("attention", matrix.shape, matrix[:5, :5])
|
|
||||||
print("text_indices", text_indices)
|
|
||||||
print("time", time_indices)
|
|
||||||
print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
|
|
||||||
print("eot", tokenizer.eot)
|
|
||||||
|
|
||||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||||
if len(word_tokens) <= 1:
|
if len(word_tokens) <= 1:
|
||||||
# return on eot only
|
# return on eot only
|
||||||
@@ -238,9 +226,7 @@ def find_alignment(
|
|||||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||||
|
|
||||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||||
# print("jumps", jumps, jumps.shape)
|
|
||||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||||
# print("jump_times", jump_times)
|
|
||||||
start_times = jump_times[word_boundaries[:-1]]
|
start_times = jump_times[word_boundaries[:-1]]
|
||||||
end_times = jump_times[word_boundaries[1:]]
|
end_times = jump_times[word_boundaries[1:]]
|
||||||
word_probabilities = [
|
word_probabilities = [
|
||||||
@@ -315,6 +301,7 @@ def add_word_timestamps(
|
|||||||
word_durations = np.array([t.end - t.start for t in alignment])
|
word_durations = np.array([t.end - t.start for t in alignment])
|
||||||
word_durations = word_durations[word_durations.nonzero()]
|
word_durations = word_durations[word_durations.nonzero()]
|
||||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||||
|
median_duration = min(0.7, float(median_duration))
|
||||||
max_duration = median_duration * 2
|
max_duration = median_duration * 2
|
||||||
|
|
||||||
# hack: truncate long words at sentence boundaries.
|
# hack: truncate long words at sentence boundaries.
|
||||||
|
|||||||
@@ -1,501 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from whisper.audio import (
|
|
||||||
FRAMES_PER_SECOND,
|
|
||||||
HOP_LENGTH,
|
|
||||||
N_FRAMES,
|
|
||||||
N_SAMPLES,
|
|
||||||
SAMPLE_RATE,
|
|
||||||
log_mel_spectrogram,
|
|
||||||
pad_or_trim,
|
|
||||||
)
|
|
||||||
from whisper.decoding import DecodingOptions, DecodingResult
|
|
||||||
from whisper.timing import add_word_timestamps
|
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
|
||||||
from whisper.utils import (
|
|
||||||
exact_div,
|
|
||||||
format_timestamp,
|
|
||||||
get_writer,
|
|
||||||
make_safe,
|
|
||||||
optional_float,
|
|
||||||
optional_int,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from whisper.model import Whisper
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe(
|
|
||||||
model: "Whisper",
|
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
|
||||||
*,
|
|
||||||
verbose: Optional[bool] = None,
|
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
|
||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
|
||||||
logprob_threshold: Optional[float] = -1.0,
|
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
|
||||||
condition_on_previous_text: bool = True,
|
|
||||||
initial_prompt: Optional[str] = None,
|
|
||||||
word_timestamps: bool = False,
|
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
|
||||||
**decode_options,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Transcribe an audio file using Whisper
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model: Whisper
|
|
||||||
The Whisper model instance
|
|
||||||
|
|
||||||
audio: Union[str, np.ndarray, torch.Tensor]
|
|
||||||
The path to the audio file to open, or the audio waveform
|
|
||||||
|
|
||||||
verbose: bool
|
|
||||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
|
||||||
If False, displays minimal details. If None, does not display anything
|
|
||||||
|
|
||||||
temperature: Union[float, Tuple[float, ...]]
|
|
||||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
|
||||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
|
||||||
|
|
||||||
compression_ratio_threshold: float
|
|
||||||
If the gzip compression ratio is above this value, treat as failed
|
|
||||||
|
|
||||||
logprob_threshold: float
|
|
||||||
If the average log probability over sampled tokens is below this value, treat as failed
|
|
||||||
|
|
||||||
no_speech_threshold: float
|
|
||||||
If the no_speech probability is higher than this value AND the average log probability
|
|
||||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
|
||||||
|
|
||||||
condition_on_previous_text: bool
|
|
||||||
if True, the previous output of the model is provided as a prompt for the next window;
|
|
||||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
|
||||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
|
||||||
|
|
||||||
word_timestamps: bool
|
|
||||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
|
||||||
and include the timestamps for each word in each segment.
|
|
||||||
|
|
||||||
prepend_punctuations: str
|
|
||||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
|
||||||
|
|
||||||
append_punctuations: str
|
|
||||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
|
||||||
|
|
||||||
initial_prompt: Optional[str]
|
|
||||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
|
||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
|
||||||
to make it more likely to predict those word correctly.
|
|
||||||
|
|
||||||
decode_options: dict
|
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
|
||||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
|
||||||
"""
|
|
||||||
# print("HACKED")
|
|
||||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
|
||||||
if model.device == torch.device("cpu"):
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
|
||||||
if dtype == torch.float16:
|
|
||||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
decode_options["fp16"] = False
|
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
|
||||||
mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # 添加16000*30 = 480000个点
|
|
||||||
# mel = pad_or_trim(mel, 3000)
|
|
||||||
content_frames = mel.shape[-1] # - N_FRAMES # 对应3000帧;真正有内容的是去掉尾部3000的那些数据
|
|
||||||
|
|
||||||
# 判断语种
|
|
||||||
if decode_options.get("language", None) is None:
|
|
||||||
# 如果是单语种模型,直接设成英文
|
|
||||||
if not model.is_multilingual:
|
|
||||||
decode_options["language"] = "en"
|
|
||||||
# 否则需要前传一次
|
|
||||||
else:
|
|
||||||
if verbose:
|
|
||||||
print(
|
|
||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
|
||||||
)
|
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
|
||||||
# print(mel_segment.shape)
|
|
||||||
_, probs = model.detect_language(mel_segment)
|
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
|
||||||
if verbose is not None:
|
|
||||||
print(
|
|
||||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
language: str = decode_options["language"]
|
|
||||||
task: str = decode_options.get("task", "transcribe")
|
|
||||||
# 输出编码器
|
|
||||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
|
||||||
|
|
||||||
# 词级别时间戳
|
|
||||||
if word_timestamps and task == "translate":
|
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
|
||||||
|
|
||||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
|
||||||
temperatures = (
|
|
||||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
|
||||||
)
|
|
||||||
decode_result = None
|
|
||||||
|
|
||||||
for t in temperatures:
|
|
||||||
kwargs = {**decode_options}
|
|
||||||
if t > 0:
|
|
||||||
# disable beam_size and patience when t > 0
|
|
||||||
kwargs.pop("beam_size", None)
|
|
||||||
kwargs.pop("patience", None)
|
|
||||||
else:
|
|
||||||
# disable best_of when t == 0
|
|
||||||
kwargs.pop("best_of", None)
|
|
||||||
|
|
||||||
options = DecodingOptions(**kwargs, temperature=t)
|
|
||||||
decode_result = model.decode(segment, options)
|
|
||||||
|
|
||||||
# 几种解码可能失败的情况。这些情况下会重复解码
|
|
||||||
# 感觉是一种KnowHow的东西 或许ChatGPT里有不少这种trick
|
|
||||||
needs_fallback = False
|
|
||||||
if (
|
|
||||||
compression_ratio_threshold is not None
|
|
||||||
and decode_result.compression_ratio > compression_ratio_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = True # too repetitive
|
|
||||||
if (
|
|
||||||
logprob_threshold is not None
|
|
||||||
and decode_result.avg_logprob < logprob_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = True # average log probability is too low
|
|
||||||
if (
|
|
||||||
no_speech_threshold is not None
|
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
|
||||||
):
|
|
||||||
needs_fallback = False # silence
|
|
||||||
if not needs_fallback:
|
|
||||||
break
|
|
||||||
# print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
|
|
||||||
# t,
|
|
||||||
# decode_result.compression_ratio, compression_ratio_threshold,
|
|
||||||
# -decode_result.avg_logprob, -logprob_threshold,
|
|
||||||
# decode_result.no_speech_prob, no_speech_threshold
|
|
||||||
# ))
|
|
||||||
|
|
||||||
return decode_result
|
|
||||||
|
|
||||||
seek = 0
|
|
||||||
input_stride = exact_div(
|
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
|
||||||
) # mel frames per output token: 2
|
|
||||||
# 这里output token指的应该是CNN输出的那个东西
|
|
||||||
|
|
||||||
time_precision = (
|
|
||||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
|
||||||
) # time per output token: 0.02 (seconds)
|
|
||||||
all_tokens = []
|
|
||||||
all_segments = []
|
|
||||||
prompt_reset_since = 0
|
|
||||||
|
|
||||||
if initial_prompt is not None:
|
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
|
||||||
else:
|
|
||||||
initial_prompt_tokens = []
|
|
||||||
|
|
||||||
def new_segment(
|
|
||||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
|
||||||
):
|
|
||||||
tokens = tokens.tolist()
|
|
||||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
|
||||||
return {
|
|
||||||
"seek": seek,
|
|
||||||
"start": start,
|
|
||||||
"end": end,
|
|
||||||
"text": tokenizer.decode(text_tokens),
|
|
||||||
"tokens": tokens,
|
|
||||||
"temperature": result.temperature,
|
|
||||||
"avg_logprob": result.avg_logprob,
|
|
||||||
"compression_ratio": result.compression_ratio,
|
|
||||||
"no_speech_prob": result.no_speech_prob,
|
|
||||||
}
|
|
||||||
|
|
||||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
|
||||||
with tqdm.tqdm(
|
|
||||||
total=content_frames, unit="frames", disable=verbose is not False
|
|
||||||
) as pbar:
|
|
||||||
last_speech_timestamp = 0.0
|
|
||||||
while seek < content_frames: # seek:标记mel频谱当前帧的位置 直接跳过Padding上的部分
|
|
||||||
# print("seek segments", seek, content_frames)
|
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # 本片段的开始时间
|
|
||||||
# mel_segment = mel[:, seek : seek + N_FRAMES] # 获得当前片段的数据
|
|
||||||
mel_segment = mel[:, seek:]
|
|
||||||
segment_size = min(N_FRAMES, content_frames - seek) # segment_size: 排除padding的真的长度。content_frames:有内容的段的真正长度 如果不够N_FRAMES的话就会截断
|
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # 当前片段的时长
|
|
||||||
mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # 补到mel_segment帧
|
|
||||||
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
|
||||||
tokens = torch.tensor(result.tokens)
|
|
||||||
|
|
||||||
# 跳过静音部分
|
|
||||||
if no_speech_threshold is not None:
|
|
||||||
# no voice activity check
|
|
||||||
should_skip = result.no_speech_prob > no_speech_threshold
|
|
||||||
if (
|
|
||||||
logprob_threshold is not None
|
|
||||||
and result.avg_logprob > logprob_threshold
|
|
||||||
):
|
|
||||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
|
||||||
should_skip = False
|
|
||||||
|
|
||||||
if should_skip:
|
|
||||||
seek += segment_size # fast-forward to the next segment boundary
|
|
||||||
continue
|
|
||||||
|
|
||||||
previous_seek = seek
|
|
||||||
current_segments = []
|
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp begin是<|0.00|>的token;bos比文字token大,eos的值比bos还大,所以是ge
|
|
||||||
timestamp_tokens[-1] = False
|
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # 如果最后是[False,True]:本段里一个句子结束了
|
|
||||||
|
|
||||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
|
||||||
# torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
|
|
||||||
# timestamp_token就是个一维向量吧 那为啥不直接nonzero
|
|
||||||
# 如果有两个连续的时间戳 这个会是一个一维tensor 是这两个连续时间戳的结尾位置
|
|
||||||
# 多个的话指向第二个 那如果有三个怎么办?
|
|
||||||
# 否则是个0维tensor
|
|
||||||
|
|
||||||
consecutive.add_(1) # 0维tensor+1还是0维 哪儿找的这些edge cases js是吧
|
|
||||||
if len(consecutive) > 0:
|
|
||||||
# if the output contains two consecutive timestamp tokens
|
|
||||||
slices = consecutive.tolist()
|
|
||||||
if single_timestamp_ending:
|
|
||||||
slices.append(len(tokens)) # 把最后一段的结尾也加进去
|
|
||||||
# print("many sentenses", consecutive)
|
|
||||||
last_slice = 0
|
|
||||||
for current_slice in slices:
|
|
||||||
sliced_tokens = tokens[last_slice:current_slice]
|
|
||||||
# 看起来语音开始帧、语音结束帧的位置会被编码到start_timestamp中
|
|
||||||
start_timestamp_pos = (
|
|
||||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
end_timestamp_pos = (
|
|
||||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
# 获取一个新的语音段
|
|
||||||
current_segments.append(
|
|
||||||
new_segment(
|
|
||||||
start=time_offset + start_timestamp_pos * time_precision,
|
|
||||||
end=time_offset + end_timestamp_pos * time_precision,
|
|
||||||
tokens=sliced_tokens,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
last_slice = current_slice
|
|
||||||
|
|
||||||
if single_timestamp_ending:
|
|
||||||
# single timestamp at the end means no speech after the last timestamp.
|
|
||||||
seek += segment_size
|
|
||||||
else:
|
|
||||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
|
||||||
# 如果语音尚未结束,那么seek变为上一个结束的语段的位置
|
|
||||||
# 换句话说就是针对30s长的chunk的语音设计的
|
|
||||||
last_timestamp_pos = (
|
|
||||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
seek += last_timestamp_pos * input_stride
|
|
||||||
else:
|
|
||||||
duration = segment_duration
|
|
||||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
|
||||||
# print(timestamps)
|
|
||||||
if (
|
|
||||||
len(timestamps) > 0
|
|
||||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
|
||||||
):
|
|
||||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
|
||||||
# 取最后一个;假设要么有一个结束的time stamp;要么有一对儿?
|
|
||||||
# 如果里面只有一个开始的timestamp 似乎后面的东西都会被丢掉?
|
|
||||||
last_timestamp_pos = (
|
|
||||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
|
||||||
)
|
|
||||||
duration = last_timestamp_pos * time_precision
|
|
||||||
|
|
||||||
current_segments.append(
|
|
||||||
new_segment(
|
|
||||||
start=time_offset,
|
|
||||||
end=time_offset + duration,
|
|
||||||
tokens=tokens,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seek += segment_size
|
|
||||||
|
|
||||||
# 每个token有自己的时间戳
|
|
||||||
if word_timestamps:
|
|
||||||
add_word_timestamps(
|
|
||||||
segments=current_segments,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
mel=mel_segment,
|
|
||||||
num_frames=segment_size,
|
|
||||||
prepend_punctuations=prepend_punctuations,
|
|
||||||
append_punctuations=append_punctuations,
|
|
||||||
last_speech_timestamp=last_speech_timestamp,
|
|
||||||
)
|
|
||||||
word_end_timestamps = [
|
|
||||||
w["end"] for s in current_segments for w in s["words"]
|
|
||||||
]
|
|
||||||
if len(word_end_timestamps) > 0:
|
|
||||||
last_speech_timestamp = word_end_timestamps[-1]
|
|
||||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
|
||||||
seek_shift = round(
|
|
||||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
|
||||||
)
|
|
||||||
if seek_shift > 0:
|
|
||||||
seek = previous_seek + seek_shift
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
for segment in current_segments:
|
|
||||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
|
||||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
|
||||||
print(make_safe(line))
|
|
||||||
|
|
||||||
# if a segment is instantaneous or does not contain text, clear it
|
|
||||||
for i, segment in enumerate(current_segments):
|
|
||||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
|
||||||
segment["text"] = ""
|
|
||||||
segment["tokens"] = []
|
|
||||||
segment["words"] = []
|
|
||||||
|
|
||||||
# 更新结果
|
|
||||||
all_segments.extend(
|
|
||||||
[
|
|
||||||
{"id": i, **segment}
|
|
||||||
for i, segment in enumerate(
|
|
||||||
current_segments, start=len(all_segments)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
all_tokens.extend(
|
|
||||||
[token for segment in current_segments for token in segment["tokens"]]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not condition_on_previous_text or result.temperature > 0.5:
|
|
||||||
# do not feed the prompt tokens if a high temperature was used
|
|
||||||
prompt_reset_since = len(all_tokens)
|
|
||||||
|
|
||||||
# update progress bar
|
|
||||||
pbar.update(min(content_frames, seek) - previous_seek)
|
|
||||||
|
|
||||||
# print("太长了")
|
|
||||||
# break
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
|
||||||
segments=all_segments,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
|
||||||
from . import available_models
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
|
||||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
|
||||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
|
||||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
|
||||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
|
||||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
|
||||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
|
||||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
|
||||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
|
||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
|
||||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
|
||||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
|
||||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
|
||||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
|
||||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
|
||||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
|
||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
|
||||||
model_name: str = args.pop("model")
|
|
||||||
model_dir: str = args.pop("model_dir")
|
|
||||||
output_dir: str = args.pop("output_dir")
|
|
||||||
output_format: str = args.pop("output_format")
|
|
||||||
device: str = args.pop("device")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
|
||||||
if args["language"] is not None:
|
|
||||||
warnings.warn(
|
|
||||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
|
||||||
)
|
|
||||||
args["language"] = "en"
|
|
||||||
|
|
||||||
temperature = args.pop("temperature")
|
|
||||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
|
||||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
|
||||||
else:
|
|
||||||
temperature = [temperature]
|
|
||||||
|
|
||||||
if (threads := args.pop("threads")) > 0:
|
|
||||||
torch.set_num_threads(threads)
|
|
||||||
|
|
||||||
from . import load_model
|
|
||||||
|
|
||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
|
||||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
|
||||||
if not args["word_timestamps"]:
|
|
||||||
for option in word_options:
|
|
||||||
if args[option]:
|
|
||||||
parser.error(f"--{option} requires --word_timestamps True")
|
|
||||||
if args["max_line_count"] and not args["max_line_width"]:
|
|
||||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
|
||||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
|
||||||
for audio_path in args.pop("audio"):
|
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
|
||||||
writer(result, audio_path, writer_args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
cli()
|
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -22,6 +23,7 @@ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
exact_div,
|
exact_div,
|
||||||
format_timestamp,
|
format_timestamp,
|
||||||
|
get_end,
|
||||||
get_writer,
|
get_writer,
|
||||||
make_safe,
|
make_safe,
|
||||||
optional_float,
|
optional_float,
|
||||||
@@ -44,9 +46,12 @@ def transcribe(
|
|||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
|
carry_initial_prompt: bool = False,
|
||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
|
clip_timestamps: Union[str, List[float]] = "0",
|
||||||
|
hallucination_silence_threshold: Optional[float] = None,
|
||||||
**decode_options,
|
**decode_options,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -98,15 +103,27 @@ def transcribe(
|
|||||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||||
to make it more likely to predict those word correctly.
|
to make it more likely to predict those word correctly.
|
||||||
|
|
||||||
|
carry_initial_prompt: bool
|
||||||
|
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
||||||
|
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
||||||
|
left-sliced to make space.
|
||||||
|
|
||||||
decode_options: dict
|
decode_options: dict
|
||||||
Keyword arguments to construct `DecodingOptions` instances
|
Keyword arguments to construct `DecodingOptions` instances
|
||||||
|
|
||||||
|
clip_timestamps: Union[str, List[float]]
|
||||||
|
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||||
|
The last end timestamp defaults to the end of the file.
|
||||||
|
|
||||||
|
hallucination_silence_threshold: Optional[float]
|
||||||
|
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||||
|
when a possible hallucination is detected
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||||
"""
|
"""
|
||||||
# print("transcribe")
|
|
||||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||||
if model.device == torch.device("cpu"):
|
if model.device == torch.device("cpu"):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -119,8 +136,9 @@ def transcribe(
|
|||||||
decode_options["fp16"] = False
|
decode_options["fp16"] = False
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||||
content_frames = mel.shape[-1] - N_FRAMES
|
content_frames = mel.shape[-1] - N_FRAMES
|
||||||
|
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||||
|
|
||||||
if decode_options.get("language", None) is None:
|
if decode_options.get("language", None) is None:
|
||||||
if not model.is_multilingual:
|
if not model.is_multilingual:
|
||||||
@@ -131,7 +149,6 @@ def transcribe(
|
|||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||||
)
|
)
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||||
# print(mel_segment.shape)
|
|
||||||
_, probs = model.detect_language(mel_segment)
|
_, probs = model.detect_language(mel_segment)
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
decode_options["language"] = max(probs, key=probs.get)
|
||||||
if verbose is not None:
|
if verbose is not None:
|
||||||
@@ -141,7 +158,25 @@ def transcribe(
|
|||||||
|
|
||||||
language: str = decode_options["language"]
|
language: str = decode_options["language"]
|
||||||
task: str = decode_options.get("task", "transcribe")
|
task: str = decode_options.get("task", "transcribe")
|
||||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(clip_timestamps, str):
|
||||||
|
clip_timestamps = [
|
||||||
|
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||||
|
]
|
||||||
|
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||||
|
if len(seek_points) == 0:
|
||||||
|
seek_points.append(0)
|
||||||
|
if len(seek_points) % 2 == 1:
|
||||||
|
seek_points.append(content_frames)
|
||||||
|
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||||
|
|
||||||
|
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||||
|
|
||||||
if word_timestamps and task == "translate":
|
if word_timestamps and task == "translate":
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||||
@@ -179,6 +214,8 @@ def transcribe(
|
|||||||
if (
|
if (
|
||||||
no_speech_threshold is not None
|
no_speech_threshold is not None
|
||||||
and decode_result.no_speech_prob > no_speech_threshold
|
and decode_result.no_speech_prob > no_speech_threshold
|
||||||
|
and logprob_threshold is not None
|
||||||
|
and decode_result.avg_logprob < logprob_threshold
|
||||||
):
|
):
|
||||||
needs_fallback = False # silence
|
needs_fallback = False # silence
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
@@ -186,7 +223,8 @@ def transcribe(
|
|||||||
|
|
||||||
return decode_result
|
return decode_result
|
||||||
|
|
||||||
seek = 0
|
clip_idx = 0
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
input_stride = exact_div(
|
input_stride = exact_div(
|
||||||
N_FRAMES, model.dims.n_audio_ctx
|
N_FRAMES, model.dims.n_audio_ctx
|
||||||
) # mel frames per output token: 2
|
) # mel frames per output token: 2
|
||||||
@@ -197,9 +235,11 @@ def transcribe(
|
|||||||
all_segments = []
|
all_segments = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
|
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
||||||
if initial_prompt is not None:
|
if initial_prompt is not None:
|
||||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
|
remaining_prompt_length -= len(initial_prompt_tokens)
|
||||||
else:
|
else:
|
||||||
initial_prompt_tokens = []
|
initial_prompt_tokens = []
|
||||||
|
|
||||||
@@ -225,16 +265,33 @@ def transcribe(
|
|||||||
total=content_frames, unit="frames", disable=verbose is not False
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
) as pbar:
|
) as pbar:
|
||||||
last_speech_timestamp = 0.0
|
last_speech_timestamp = 0.0
|
||||||
while seek < content_frames:
|
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||||
|
# A later commit should turn this into a simpler nested loop.
|
||||||
|
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||||
|
# while seek < seek_clip_end
|
||||||
|
while clip_idx < len(seek_clips):
|
||||||
|
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||||
|
if seek < seek_clip_start:
|
||||||
|
seek = seek_clip_start
|
||||||
|
if seek >= seek_clip_end:
|
||||||
|
clip_idx += 1
|
||||||
|
if clip_idx < len(seek_clips):
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
|
continue
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||||
segment_size = min(N_FRAMES, content_frames - seek)
|
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||||
|
mel_segment = mel[:, seek : seek + segment_size]
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
# print("melshape", mel_segment.shape)
|
if carry_initial_prompt:
|
||||||
|
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
||||||
|
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
||||||
|
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
||||||
|
else:
|
||||||
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
|
|
||||||
@@ -255,6 +312,30 @@ def transcribe(
|
|||||||
previous_seek = seek
|
previous_seek = seek
|
||||||
current_segments = []
|
current_segments = []
|
||||||
|
|
||||||
|
# anomalous words are very long/short/improbable
|
||||||
|
def word_anomaly_score(word: dict) -> float:
|
||||||
|
probability = word.get("probability", 0.0)
|
||||||
|
duration = word["end"] - word["start"]
|
||||||
|
score = 0.0
|
||||||
|
if probability < 0.15:
|
||||||
|
score += 1.0
|
||||||
|
if duration < 0.133:
|
||||||
|
score += (0.133 - duration) * 15
|
||||||
|
if duration > 2.0:
|
||||||
|
score += duration - 2.0
|
||||||
|
return score
|
||||||
|
|
||||||
|
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||||
|
if segment is None or not segment["words"]:
|
||||||
|
return False
|
||||||
|
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||||
|
words = words[:8]
|
||||||
|
score = sum(word_anomaly_score(w) for w in words)
|
||||||
|
return score >= 3 or score + 0.01 >= len(words)
|
||||||
|
|
||||||
|
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||||
|
return next((s for s in segments if s["words"]), None)
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
|
|
||||||
@@ -317,9 +398,7 @@ def transcribe(
|
|||||||
)
|
)
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
|
|
||||||
# print("word_timestamps, ", word_timestamps)
|
|
||||||
if word_timestamps:
|
if word_timestamps:
|
||||||
# print("=========run timestamps here=========")
|
|
||||||
add_word_timestamps(
|
add_word_timestamps(
|
||||||
segments=current_segments,
|
segments=current_segments,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -330,17 +409,71 @@ def transcribe(
|
|||||||
append_punctuations=append_punctuations,
|
append_punctuations=append_punctuations,
|
||||||
last_speech_timestamp=last_speech_timestamp,
|
last_speech_timestamp=last_speech_timestamp,
|
||||||
)
|
)
|
||||||
word_end_timestamps = [
|
|
||||||
w["end"] for s in current_segments for w in s["words"]
|
if not single_timestamp_ending:
|
||||||
]
|
last_word_end = get_end(current_segments)
|
||||||
if len(word_end_timestamps) > 0:
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
last_speech_timestamp = word_end_timestamps[-1]
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
|
||||||
seek_shift = round(
|
# skip silence before possible hallucinations
|
||||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
if hallucination_silence_threshold is not None:
|
||||||
)
|
threshold = hallucination_silence_threshold
|
||||||
if seek_shift > 0:
|
if not single_timestamp_ending:
|
||||||
seek = previous_seek + seek_shift
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
|
remaining_duration = window_end_time - last_word_end
|
||||||
|
if remaining_duration > threshold:
|
||||||
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
|
else:
|
||||||
|
seek = previous_seek + segment_size
|
||||||
|
|
||||||
|
# if first segment might be a hallucination, skip leading silence
|
||||||
|
first_segment = next_words_segment(current_segments)
|
||||||
|
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||||
|
gap = first_segment["start"] - time_offset
|
||||||
|
if gap > threshold:
|
||||||
|
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip silence before any possible hallucination that is surrounded
|
||||||
|
# by silence or more hallucinations
|
||||||
|
hal_last_end = last_speech_timestamp
|
||||||
|
for si in range(len(current_segments)):
|
||||||
|
segment = current_segments[si]
|
||||||
|
if not segment["words"]:
|
||||||
|
continue
|
||||||
|
if is_segment_anomaly(segment):
|
||||||
|
next_segment = next_words_segment(
|
||||||
|
current_segments[si + 1 :]
|
||||||
|
)
|
||||||
|
if next_segment is not None:
|
||||||
|
hal_next_start = next_segment["words"][0]["start"]
|
||||||
|
else:
|
||||||
|
hal_next_start = time_offset + segment_duration
|
||||||
|
silence_before = (
|
||||||
|
segment["start"] - hal_last_end > threshold
|
||||||
|
or segment["start"] < threshold
|
||||||
|
or segment["start"] - time_offset < 2.0
|
||||||
|
)
|
||||||
|
silence_after = (
|
||||||
|
hal_next_start - segment["end"] > threshold
|
||||||
|
or is_segment_anomaly(next_segment)
|
||||||
|
or window_end_time - segment["end"] < 2.0
|
||||||
|
)
|
||||||
|
if silence_before and silence_after:
|
||||||
|
seek = round(
|
||||||
|
max(time_offset + 1, segment["start"])
|
||||||
|
* FRAMES_PER_SECOND
|
||||||
|
)
|
||||||
|
if content_duration - segment["end"] < threshold:
|
||||||
|
seek = content_frames
|
||||||
|
current_segments[si:] = []
|
||||||
|
break
|
||||||
|
hal_last_end = segment["end"]
|
||||||
|
|
||||||
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None:
|
||||||
|
last_speech_timestamp = last_word_end
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
for segment in current_segments:
|
for segment in current_segments:
|
||||||
@@ -384,10 +517,17 @@ def transcribe(
|
|||||||
def cli():
|
def cli():
|
||||||
from . import available_models
|
from . import available_models
|
||||||
|
|
||||||
|
def valid_model_name(name):
|
||||||
|
if name in available_models() or os.path.exists(name):
|
||||||
|
return name
|
||||||
|
raise ValueError(
|
||||||
|
f"model should be one of {available_models()} or path to a model checkpoint"
|
||||||
|
)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
@@ -405,6 +545,8 @@ def cli():
|
|||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
|
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
||||||
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
|
|
||||||
@@ -418,7 +560,10 @@ def cli():
|
|||||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||||
|
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||||
|
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||||
|
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
args = parser.parse_args().__dict__
|
||||||
@@ -450,17 +595,28 @@ def cli():
|
|||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir)
|
||||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
word_options = [
|
||||||
|
"highlight_words",
|
||||||
|
"max_line_count",
|
||||||
|
"max_line_width",
|
||||||
|
"max_words_per_line",
|
||||||
|
]
|
||||||
if not args["word_timestamps"]:
|
if not args["word_timestamps"]:
|
||||||
for option in word_options:
|
for option in word_options:
|
||||||
if args[option]:
|
if args[option]:
|
||||||
parser.error(f"--{option} requires --word_timestamps True")
|
parser.error(f"--{option} requires --word_timestamps True")
|
||||||
if args["max_line_count"] and not args["max_line_width"]:
|
if args["max_line_count"] and not args["max_line_width"]:
|
||||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||||
|
if args["max_words_per_line"] and args["max_line_width"]:
|
||||||
|
warnings.warn("--max_words_per_line has no effect with --max_line_width")
|
||||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||||
for audio_path in args.pop("audio"):
|
for audio_path in args.pop("audio"):
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
try:
|
||||||
writer(result, audio_path, writer_args)
|
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||||
|
writer(result, audio_path, **writer_args)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
|
|||||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||||
|
|
||||||
kernel = triton.JITFunction(kernel.fn)
|
kernel = triton.JITFunction(kernel.fn)
|
||||||
kernel.src = kernel.src.replace(
|
new_kernel = kernel.src.replace(
|
||||||
" LOAD_ALL_ROWS_HERE",
|
" LOAD_ALL_ROWS_HERE",
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
@@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace(
|
|
||||||
|
new_kernel = new_kernel.replace(
|
||||||
" BUBBLESORT_HERE",
|
" BUBBLESORT_HERE",
|
||||||
"\n\n".join(
|
"\n\n".join(
|
||||||
[
|
[
|
||||||
@@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
|
||||||
|
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||||
|
|
||||||
|
if hasattr(kernel, "_unsafe_update_src") is True:
|
||||||
|
kernel._unsafe_update_src(new_kernel)
|
||||||
|
kernel.hash = None
|
||||||
|
else:
|
||||||
|
kernel.src = new_kernel
|
||||||
|
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import zlib
|
import zlib
|
||||||
from typing import Callable, Optional, TextIO
|
from typing import Callable, List, Optional, TextIO
|
||||||
|
|
||||||
system_encoding = sys.getdefaultencoding()
|
system_encoding = sys.getdefaultencoding()
|
||||||
|
|
||||||
@@ -68,13 +68,29 @@ def format_timestamp(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_start(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["start"] for s in segments for w in s["words"]),
|
||||||
|
segments[0]["start"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_end(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||||
|
segments[-1]["end"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResultWriter:
|
class ResultWriter:
|
||||||
extension: str
|
extension: str
|
||||||
|
|
||||||
def __init__(self, output_dir: str):
|
def __init__(self, output_dir: str):
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
||||||
def __call__(self, result: dict, audio_path: str, options: dict):
|
def __call__(
|
||||||
|
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
audio_basename = os.path.basename(audio_path)
|
audio_basename = os.path.basename(audio_path)
|
||||||
audio_basename = os.path.splitext(audio_basename)[0]
|
audio_basename = os.path.splitext(audio_basename)[0]
|
||||||
output_path = os.path.join(
|
output_path = os.path.join(
|
||||||
@@ -82,16 +98,20 @@ class ResultWriter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
self.write_result(result, file=f, options=options)
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class WriteTXT(ResultWriter):
|
class WriteTXT(ResultWriter):
|
||||||
extension: str = "txt"
|
extension: str = "txt"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(segment["text"].strip(), file=file, flush=True)
|
print(segment["text"].strip(), file=file, flush=True)
|
||||||
|
|
||||||
@@ -100,48 +120,76 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
always_include_hours: bool
|
always_include_hours: bool
|
||||||
decimal_marker: str
|
decimal_marker: str
|
||||||
|
|
||||||
def iterate_result(self, result: dict, options: dict):
|
def iterate_result(
|
||||||
raw_max_line_width: Optional[int] = options["max_line_width"]
|
self,
|
||||||
max_line_count: Optional[int] = options["max_line_count"]
|
result: dict,
|
||||||
highlight_words: bool = options["highlight_words"]
|
options: Optional[dict] = None,
|
||||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
*,
|
||||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
max_line_width: Optional[int] = None,
|
||||||
|
max_line_count: Optional[int] = None,
|
||||||
|
highlight_words: bool = False,
|
||||||
|
max_words_per_line: Optional[int] = None,
|
||||||
|
):
|
||||||
|
options = options or {}
|
||||||
|
max_line_width = max_line_width or options.get("max_line_width")
|
||||||
|
max_line_count = max_line_count or options.get("max_line_count")
|
||||||
|
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||||
|
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||||
|
preserve_segments = max_line_count is None or max_line_width is None
|
||||||
|
max_line_width = max_line_width or 1000
|
||||||
|
max_words_per_line = max_words_per_line or 1000
|
||||||
|
|
||||||
def iterate_subtitles():
|
def iterate_subtitles():
|
||||||
line_len = 0
|
line_len = 0
|
||||||
line_count = 1
|
line_count = 1
|
||||||
# the next subtitle to yield (a list of word timings with whitespace)
|
# the next subtitle to yield (a list of word timings with whitespace)
|
||||||
subtitle: list[dict] = []
|
subtitle: List[dict] = []
|
||||||
last = result["segments"][0]["words"][0]["start"]
|
last: float = get_start(result["segments"]) or 0.0
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
for i, original_timing in enumerate(segment["words"]):
|
chunk_index = 0
|
||||||
timing = original_timing.copy()
|
words_count = max_words_per_line
|
||||||
long_pause = not preserve_segments and timing["start"] - last > 3.0
|
while chunk_index < len(segment["words"]):
|
||||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
remaining_words = len(segment["words"]) - chunk_index
|
||||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
words_count = remaining_words
|
||||||
# line continuation
|
for i, original_timing in enumerate(
|
||||||
line_len += len(timing["word"])
|
segment["words"][chunk_index : chunk_index + words_count]
|
||||||
else:
|
):
|
||||||
# new line
|
timing = original_timing.copy()
|
||||||
timing["word"] = timing["word"].strip()
|
long_pause = (
|
||||||
|
not preserve_segments and timing["start"] - last > 3.0
|
||||||
|
)
|
||||||
|
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||||
|
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||||
if (
|
if (
|
||||||
len(subtitle) > 0
|
line_len > 0
|
||||||
and max_line_count is not None
|
and has_room
|
||||||
and (long_pause or line_count >= max_line_count)
|
and not long_pause
|
||||||
or seg_break
|
and not seg_break
|
||||||
):
|
):
|
||||||
# subtitle break
|
# line continuation
|
||||||
yield subtitle
|
line_len += len(timing["word"])
|
||||||
subtitle = []
|
else:
|
||||||
line_count = 1
|
# new line
|
||||||
elif line_len > 0:
|
timing["word"] = timing["word"].strip()
|
||||||
# line break
|
if (
|
||||||
line_count += 1
|
len(subtitle) > 0
|
||||||
timing["word"] = "\n" + timing["word"]
|
and max_line_count is not None
|
||||||
line_len = len(timing["word"].strip())
|
and (long_pause or line_count >= max_line_count)
|
||||||
subtitle.append(timing)
|
or seg_break
|
||||||
last = timing["start"]
|
):
|
||||||
|
# subtitle break
|
||||||
|
yield subtitle
|
||||||
|
subtitle = []
|
||||||
|
line_count = 1
|
||||||
|
elif line_len > 0:
|
||||||
|
# line break
|
||||||
|
line_count += 1
|
||||||
|
timing["word"] = "\n" + timing["word"]
|
||||||
|
line_len = len(timing["word"].strip())
|
||||||
|
subtitle.append(timing)
|
||||||
|
last = timing["start"]
|
||||||
|
chunk_index += max_words_per_line
|
||||||
if len(subtitle) > 0:
|
if len(subtitle) > 0:
|
||||||
yield subtitle
|
yield subtitle
|
||||||
|
|
||||||
@@ -161,9 +209,11 @@ class SubtitlesWriter(ResultWriter):
|
|||||||
|
|
||||||
yield start, end, "".join(
|
yield start, end, "".join(
|
||||||
[
|
[
|
||||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
(
|
||||||
if j == i
|
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||||
else word
|
if j == i
|
||||||
|
else word
|
||||||
|
)
|
||||||
for j, word in enumerate(all_words)
|
for j, word in enumerate(all_words)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -190,9 +240,11 @@ class WriteVTT(SubtitlesWriter):
|
|||||||
always_include_hours: bool = False
|
always_include_hours: bool = False
|
||||||
decimal_marker: str = "."
|
decimal_marker: str = "."
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
print("WEBVTT\n", file=file)
|
print("WEBVTT\n", file=file)
|
||||||
for start, end, text in self.iterate_result(result, options):
|
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -201,9 +253,11 @@ class WriteSRT(SubtitlesWriter):
|
|||||||
always_include_hours: bool = True
|
always_include_hours: bool = True
|
||||||
decimal_marker: str = ","
|
decimal_marker: str = ","
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for i, (start, end, text) in enumerate(
|
for i, (start, end, text) in enumerate(
|
||||||
self.iterate_result(result, options), start=1
|
self.iterate_result(result, options, **kwargs), start=1
|
||||||
):
|
):
|
||||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
@@ -220,7 +274,9 @@ class WriteTSV(ResultWriter):
|
|||||||
|
|
||||||
extension: str = "tsv"
|
extension: str = "tsv"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
print("start", "end", "text", sep="\t", file=file)
|
print("start", "end", "text", sep="\t", file=file)
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||||
@@ -231,7 +287,9 @@ class WriteTSV(ResultWriter):
|
|||||||
class WriteJSON(ResultWriter):
|
class WriteJSON(ResultWriter):
|
||||||
extension: str = "json"
|
extension: str = "json"
|
||||||
|
|
||||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
json.dump(result, file)
|
json.dump(result, file)
|
||||||
|
|
||||||
|
|
||||||
@@ -249,9 +307,11 @@ def get_writer(
|
|||||||
if output_format == "all":
|
if output_format == "all":
|
||||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||||
|
|
||||||
def write_all(result: dict, file: TextIO, options: dict):
|
def write_all(
|
||||||
|
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
for writer in all_writers:
|
for writer in all_writers:
|
||||||
writer(result, file, options)
|
writer(result, file, options, **kwargs)
|
||||||
|
|
||||||
return write_all
|
return write_all
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "20230918"
|
__version__ = "20250625"
|
||||||
|
|||||||
@@ -29,4 +29,8 @@ class SpeakerSegment(TimedText):
|
|||||||
"""Represents a segment of audio attributed to a specific speaker.
|
"""Represents a segment of audio attributed to a specific speaker.
|
||||||
No text nor probability is associated with this segment.
|
No text nor probability is associated with this segment.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Silence():
|
||||||
|
duration: float
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
import torch
|
|
||||||
import sys
|
|
||||||
class TokenBuffer:
|
|
||||||
|
|
||||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
|
||||||
self.text = text
|
|
||||||
self.prefix_token_ids = prefix_token_ids
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def as_token_ids(self, tokenizer=None):
|
|
||||||
|
|
||||||
if tokenizer is None:
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
if tokenizer is None:
|
|
||||||
raise ValueError("Tokenizer is not set.")
|
|
||||||
return self.prefix_token_ids + tokenizer.encode(self.text)
|
|
||||||
|
|
||||||
def as_tensor(self, device=None):
|
|
||||||
if device is None:
|
|
||||||
device = self.device
|
|
||||||
if device is None:
|
|
||||||
raise ValueError("Device is not set.")
|
|
||||||
tok_ids = self.as_token_ids()
|
|
||||||
return torch.tensor(tok_ids,
|
|
||||||
dtype=torch.long, device=device).unsqueeze(0)
|
|
||||||
|
|
||||||
def as_tensor_beam(self, beam, device=None):
|
|
||||||
t = self.as_tensor(device=device)
|
|
||||||
return t.repeat_interleave(beam, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
def as_text(self):
|
|
||||||
return self.text
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def empty(*a, **kw):
|
|
||||||
return TokenBuffer(*a,**kw)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_text(text, *a, **kw):
|
|
||||||
return TokenBuffer(*a, text=text, **kw)
|
|
||||||
|
|
||||||
def is_empty(self):
|
|
||||||
return self.text is None or self.text == ""
|
|
||||||
|
|
||||||
def trim_words(self, num=1, after=0):
|
|
||||||
'''
|
|
||||||
num: how many words to trim from the beginning
|
|
||||||
after: how many characters to skip (length of the static prompt)
|
|
||||||
'''
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
|
|
||||||
ids = tokenizer.encode(self.text[after:])
|
|
||||||
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
|
||||||
print(words, file=sys.stderr)
|
|
||||||
print(wids, file=sys.stderr)
|
|
||||||
if not words:
|
|
||||||
return 0
|
|
||||||
self.text = self.text[:after] + "".join(words[num:])
|
|
||||||
return sum(len(wi) for wi in wids[:num])
|
|
||||||
|
|
||||||
def append_token_ids(self, token_ids):
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
self.text += self.tokenizer.decode(token_ids)
|
|
||||||
|
|
||||||
def as_split_word_tokens(self):
|
|
||||||
tokenizer = self.tokenizer
|
|
||||||
assert tokenizer is not None, "Tokenizer is not set."
|
|
||||||
ids = tokenizer.encode(self.text)
|
|
||||||
return tokenizer.split_to_word_tokens(ids)
|
|
||||||
60
whisperlivekit/trail_repetition.py
Normal file
60
whisperlivekit/trail_repetition.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from typing import Sequence, Callable, Any, Optional, Dict
|
||||||
|
|
||||||
|
def _detect_tail_repetition(
|
||||||
|
seq: Sequence[Any],
|
||||||
|
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||||
|
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||||
|
max_tail: int = 300, # search window from the end for speed
|
||||||
|
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
vals = [key(x) for x in seq][-max_tail:]
|
||||||
|
n = len(vals)
|
||||||
|
best = None
|
||||||
|
|
||||||
|
# try every possible block length
|
||||||
|
for b in range(min_block, n // 2 + 1):
|
||||||
|
block = vals[-b:]
|
||||||
|
# count how many times this block repeats contiguously at the very end
|
||||||
|
count, i = 0, n
|
||||||
|
while i - b >= 0 and vals[i - b:i] == block:
|
||||||
|
count += 1
|
||||||
|
i -= b
|
||||||
|
|
||||||
|
if count >= 2:
|
||||||
|
cand = {
|
||||||
|
"block_size": b,
|
||||||
|
"count": count,
|
||||||
|
"start_index": len(seq) - count * b, # in original seq
|
||||||
|
"end_index": len(seq),
|
||||||
|
}
|
||||||
|
if (best is None or
|
||||||
|
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||||
|
(prefer == "smallest" and b < best["block_size"])):
|
||||||
|
best = cand
|
||||||
|
return best
|
||||||
|
|
||||||
|
def trim_tail_repetition(
|
||||||
|
seq: Sequence[Any],
|
||||||
|
key: Callable[[Any], Any] = lambda x: x,
|
||||||
|
min_block: int = 1,
|
||||||
|
max_tail: int = 300,
|
||||||
|
prefer: str = "longest",
|
||||||
|
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns a new sequence with repeated tail trimmed.
|
||||||
|
keep=1 -> keep a single copy of the repeated block.
|
||||||
|
keep=0 -> remove all copies of the repeated block.
|
||||||
|
"""
|
||||||
|
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||||
|
if not rep:
|
||||||
|
return seq, False # nothing to trim
|
||||||
|
|
||||||
|
b, c = rep["block_size"], rep["count"]
|
||||||
|
if keep < 0:
|
||||||
|
keep = 0
|
||||||
|
if keep >= c:
|
||||||
|
return seq, False # nothing to trim (already <= keep copies)
|
||||||
|
# new length = total - (copies_to_remove * block_size)
|
||||||
|
new_len = len(seq) - (c - keep) * b
|
||||||
|
return seq[:new_len], True
|
||||||
62
whisperlivekit/warmup.py
Normal file
62
whisperlivekit/warmup.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def load_file(warmup_file=None, timeout=5):
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
if warmup_file is None:
|
||||||
|
# Download JFK sample if not already present
|
||||||
|
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
||||||
|
|
||||||
|
if not os.path.exists(warmup_file):
|
||||||
|
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||||
|
print(f"Downloading warmup file from {jfk_url}")
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
import socket
|
||||||
|
|
||||||
|
original_timeout = socket.getdefaulttimeout()
|
||||||
|
socket.setdefaulttimeout(timeout)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
urllib.request.urlretrieve(jfk_url, warmup_file)
|
||||||
|
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
|
||||||
|
except (urllib.error.URLError, socket.timeout) as e:
|
||||||
|
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
socket.setdefaulttimeout(original_timeout)
|
||||||
|
elif not warmup_file:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||||
|
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio, sr = librosa.load(warmup_file, sr=16000)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load audio file: {e}")
|
||||||
|
return False
|
||||||
|
return audio
|
||||||
|
|
||||||
|
def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||||
|
"""
|
||||||
|
Warmup the ASR model by transcribing a short audio file.
|
||||||
|
"""
|
||||||
|
audio = load_file(warmup_file=None, timeout=5)
|
||||||
|
asr.transcribe(audio)
|
||||||
|
logger.info("ASR model is warmed up")
|
||||||
|
|
||||||
|
def warmup_online(online, warmup_file=None, timeout=5):
|
||||||
|
audio = load_file(warmup_file=None, timeout=5)
|
||||||
|
online.warmup(audio)
|
||||||
|
logger.warning("ASR is warmed up")
|
||||||
402
whisperlivekit/web/live_transcription.css
Normal file
402
whisperlivekit/web/live_transcription.css
Normal file
@@ -0,0 +1,402 @@
|
|||||||
|
:root {
|
||||||
|
--bg: #ffffff;
|
||||||
|
--text: #111111;
|
||||||
|
--muted: #666666;
|
||||||
|
--border: #e5e5e5;
|
||||||
|
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||||
|
--chip-text: #000000;
|
||||||
|
--spinner-border: #8d8d8d5c;
|
||||||
|
--spinner-top: #b0b0b0;
|
||||||
|
--silence-bg: #f3f3f3;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||||
|
--button-bg: #ffffff;
|
||||||
|
--button-border: #e9e9e9;
|
||||||
|
--wave-stroke: #000000;
|
||||||
|
--label-dia-text: #868686;
|
||||||
|
--label-trans-text: #111111;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
:root:not([data-theme="light"]) {
|
||||||
|
--bg: #0b0b0b;
|
||||||
|
--text: #e6e6e6;
|
||||||
|
--muted: #9aa0a6;
|
||||||
|
--border: #333333;
|
||||||
|
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||||
|
--chip-text: #e6e6e6;
|
||||||
|
--spinner-border: #555555;
|
||||||
|
--spinner-top: #dddddd;
|
||||||
|
--silence-bg: #1a1a1a;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||||
|
--button-bg: #111111;
|
||||||
|
--button-border: #333333;
|
||||||
|
--wave-stroke: #e6e6e6;
|
||||||
|
--label-dia-text: #b3b3b3;
|
||||||
|
--label-trans-text: #ffffff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
:root[data-theme="dark"] {
|
||||||
|
--bg: #0b0b0b;
|
||||||
|
--text: #e6e6e6;
|
||||||
|
--muted: #9aa0a6;
|
||||||
|
--border: #333333;
|
||||||
|
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||||
|
--chip-text: #e6e6e6;
|
||||||
|
--spinner-border: #555555;
|
||||||
|
--spinner-top: #dddddd;
|
||||||
|
--silence-bg: #1a1a1a;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||||
|
--button-bg: #111111;
|
||||||
|
--button-border: #333333;
|
||||||
|
--wave-stroke: #e6e6e6;
|
||||||
|
--label-dia-text: #b3b3b3;
|
||||||
|
--label-trans-text: #ffffff;
|
||||||
|
}
|
||||||
|
|
||||||
|
:root[data-theme="light"] {
|
||||||
|
--bg: #ffffff;
|
||||||
|
--text: #111111;
|
||||||
|
--muted: #666666;
|
||||||
|
--border: #e5e5e5;
|
||||||
|
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||||
|
--chip-text: #000000;
|
||||||
|
--spinner-border: #8d8d8d5c;
|
||||||
|
--spinner-top: #b0b0b0;
|
||||||
|
--silence-bg: #f3f3f3;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||||
|
--button-bg: #ffffff;
|
||||||
|
--button-border: #e9e9e9;
|
||||||
|
--wave-stroke: #000000;
|
||||||
|
--label-dia-text: #868686;
|
||||||
|
--label-trans-text: #111111;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||||
|
margin: 20px;
|
||||||
|
text-align: center;
|
||||||
|
background-color: var(--bg);
|
||||||
|
color: var(--text);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Record button */
|
||||||
|
#recordButton {
|
||||||
|
width: 50px;
|
||||||
|
height: 50px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: var(--button-bg);
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
border: 1px solid var(--button-border);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton.recording {
|
||||||
|
width: 180px;
|
||||||
|
border-radius: 40px;
|
||||||
|
justify-content: flex-start;
|
||||||
|
padding-left: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton:active {
|
||||||
|
transform: scale(0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
.shape-container {
|
||||||
|
width: 25px;
|
||||||
|
height: 25px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.shape {
|
||||||
|
width: 25px;
|
||||||
|
height: 25px;
|
||||||
|
background-color: rgb(209, 61, 53);
|
||||||
|
border-radius: 50%;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton:disabled .shape {
|
||||||
|
background-color: #6e6d6d;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton.recording .shape {
|
||||||
|
border-radius: 5px;
|
||||||
|
width: 25px;
|
||||||
|
height: 25px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Recording elements */
|
||||||
|
.recording-info {
|
||||||
|
display: none;
|
||||||
|
align-items: center;
|
||||||
|
margin-left: 15px;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton.recording .recording-info {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
|
||||||
|
.wave-container {
|
||||||
|
width: 60px;
|
||||||
|
height: 30px;
|
||||||
|
position: relative;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
#waveCanvas {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.timer {
|
||||||
|
font-size: 14px;
|
||||||
|
font-weight: 500;
|
||||||
|
color: var(--text);
|
||||||
|
margin-left: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#status {
|
||||||
|
margin-top: 20px;
|
||||||
|
font-size: 16px;
|
||||||
|
color: var(--text);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Settings */
|
||||||
|
.settings-container {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
gap: 15px;
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: flex-start;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.field {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: flex-start;
|
||||||
|
gap: 3px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#chunkSelector,
|
||||||
|
#websocketInput,
|
||||||
|
#themeSelector {
|
||||||
|
font-size: 16px;
|
||||||
|
padding: 5px 8px;
|
||||||
|
border-radius: 8px;
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
background-color: var(--button-bg);
|
||||||
|
color: var(--text);
|
||||||
|
max-height: 34px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#websocketInput {
|
||||||
|
width: 220px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#chunkSelector:focus,
|
||||||
|
#websocketInput:focus,
|
||||||
|
#themeSelector:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #007bff;
|
||||||
|
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
label {
|
||||||
|
font-size: 13px;
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
.ws-default {
|
||||||
|
font-size: 12px;
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Segmented pill control for Theme */
|
||||||
|
.segmented {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: stretch;
|
||||||
|
border: 1px solid var(--button-border);
|
||||||
|
background-color: var(--button-bg);
|
||||||
|
border-radius: 999px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented input[type="radio"] {
|
||||||
|
position: absolute;
|
||||||
|
opacity: 0;
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.theme-selector-container {
|
||||||
|
position: absolute;
|
||||||
|
top: 20px;
|
||||||
|
right: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented label {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 6px;
|
||||||
|
padding: 6px 12px;
|
||||||
|
font-size: 14px;
|
||||||
|
color: var(--muted);
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
transition: background-color 0.2s ease, color 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented label span {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented label:hover span {
|
||||||
|
display: inline;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented label:hover {
|
||||||
|
background-color: var(--chip-bg);
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented img {
|
||||||
|
width: 16px;
|
||||||
|
height: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented input[type="radio"]:checked + label {
|
||||||
|
background-color: var(--chip-bg);
|
||||||
|
color: var(--text);
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented input[type="radio"]:focus-visible + label,
|
||||||
|
.segmented input[type="radio"]:focus + label {
|
||||||
|
outline: 2px solid #007bff;
|
||||||
|
outline-offset: 2px;
|
||||||
|
border-radius: 999px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Transcript area */
|
||||||
|
#linesTranscript {
|
||||||
|
margin: 20px auto;
|
||||||
|
max-width: 700px;
|
||||||
|
text-align: left;
|
||||||
|
font-size: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#linesTranscript p {
|
||||||
|
margin: 0px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#linesTranscript strong {
|
||||||
|
color: var(--text);
|
||||||
|
}
|
||||||
|
|
||||||
|
#speaker {
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
border-radius: 100px;
|
||||||
|
padding: 2px 10px;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-bottom: 0px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.label_diarization {
|
||||||
|
background-color: var(--chip-bg);
|
||||||
|
border-radius: 8px 8px 8px 8px;
|
||||||
|
padding: 2px 10px;
|
||||||
|
margin-left: 10px;
|
||||||
|
display: inline-block;
|
||||||
|
white-space: nowrap;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-bottom: 0px;
|
||||||
|
color: var(--label-dia-text);
|
||||||
|
}
|
||||||
|
|
||||||
|
.label_transcription {
|
||||||
|
background-color: var(--chip-bg);
|
||||||
|
border-radius: 8px 8px 8px 8px;
|
||||||
|
padding: 2px 10px;
|
||||||
|
display: inline-block;
|
||||||
|
white-space: nowrap;
|
||||||
|
margin-left: 10px;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-bottom: 0px;
|
||||||
|
color: var(--label-trans-text);
|
||||||
|
}
|
||||||
|
|
||||||
|
#timeInfo {
|
||||||
|
color: var(--muted);
|
||||||
|
margin-left: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.textcontent {
|
||||||
|
font-size: 16px;
|
||||||
|
padding-left: 10px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
margin-top: 1px;
|
||||||
|
padding-top: 5px;
|
||||||
|
border-radius: 0px 0px 0px 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.buffer_diarization {
|
||||||
|
color: var(--label-dia-text);
|
||||||
|
margin-left: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.buffer_transcription {
|
||||||
|
color: #7474748c;
|
||||||
|
margin-left: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.spinner {
|
||||||
|
display: inline-block;
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
border: 2px solid var(--spinner-border);
|
||||||
|
border-top: 2px solid var(--spinner-top);
|
||||||
|
border-radius: 50%;
|
||||||
|
animation: spin 0.7s linear infinite;
|
||||||
|
vertical-align: middle;
|
||||||
|
margin-bottom: 2px;
|
||||||
|
margin-right: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes spin {
|
||||||
|
to {
|
||||||
|
transform: rotate(360deg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.silence {
|
||||||
|
color: var(--muted);
|
||||||
|
background-color: var(--silence-bg);
|
||||||
|
font-size: 13px;
|
||||||
|
border-radius: 30px;
|
||||||
|
padding: 2px 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loading {
|
||||||
|
color: var(--muted);
|
||||||
|
background-color: var(--loading-bg);
|
||||||
|
border-radius: 8px 8px 8px 0px;
|
||||||
|
padding: 2px 10px;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-bottom: 0px;
|
||||||
|
}
|
||||||
@@ -1,682 +1,61 @@
|
|||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
|
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>Audio Transcription</title>
|
<title>WhisperLiveKit</title>
|
||||||
<style>
|
<link rel="stylesheet" href="/web/live_transcription.css" />
|
||||||
body {
|
|
||||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
|
||||||
margin: 20px;
|
|
||||||
text-align: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
#recordButton {
|
|
||||||
width: 50px;
|
|
||||||
height: 50px;
|
|
||||||
border: none;
|
|
||||||
border-radius: 50%;
|
|
||||||
background-color: white;
|
|
||||||
cursor: pointer;
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
border: 1px solid rgb(233, 233, 233);
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
position: relative;
|
|
||||||
}
|
|
||||||
|
|
||||||
#recordButton.recording {
|
|
||||||
width: 180px;
|
|
||||||
border-radius: 40px;
|
|
||||||
justify-content: flex-start;
|
|
||||||
padding-left: 20px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#recordButton:active {
|
|
||||||
transform: scale(0.95);
|
|
||||||
}
|
|
||||||
|
|
||||||
.shape-container {
|
|
||||||
width: 25px;
|
|
||||||
height: 25px;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
flex-shrink: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.shape {
|
|
||||||
width: 25px;
|
|
||||||
height: 25px;
|
|
||||||
background-color: rgb(209, 61, 53);
|
|
||||||
border-radius: 50%;
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
#recordButton:disabled .shape {
|
|
||||||
background-color: #6e6d6d;
|
|
||||||
}
|
|
||||||
|
|
||||||
#recordButton.recording .shape {
|
|
||||||
border-radius: 5px;
|
|
||||||
width: 25px;
|
|
||||||
height: 25px;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Recording elements */
|
|
||||||
.recording-info {
|
|
||||||
display: none;
|
|
||||||
align-items: center;
|
|
||||||
margin-left: 15px;
|
|
||||||
flex-grow: 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
#recordButton.recording .recording-info {
|
|
||||||
display: flex;
|
|
||||||
}
|
|
||||||
|
|
||||||
.wave-container {
|
|
||||||
width: 60px;
|
|
||||||
height: 30px;
|
|
||||||
position: relative;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
#waveCanvas {
|
|
||||||
width: 100%;
|
|
||||||
height: 100%;
|
|
||||||
}
|
|
||||||
|
|
||||||
.timer {
|
|
||||||
font-size: 14px;
|
|
||||||
font-weight: 500;
|
|
||||||
color: #333;
|
|
||||||
margin-left: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#status {
|
|
||||||
margin-top: 20px;
|
|
||||||
font-size: 16px;
|
|
||||||
color: #333;
|
|
||||||
}
|
|
||||||
|
|
||||||
.settings-container {
|
|
||||||
display: flex;
|
|
||||||
justify-content: center;
|
|
||||||
align-items: center;
|
|
||||||
gap: 15px;
|
|
||||||
margin-top: 20px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.settings {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
align-items: flex-start;
|
|
||||||
gap: 5px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#chunkSelector,
|
|
||||||
#websocketInput {
|
|
||||||
font-size: 16px;
|
|
||||||
padding: 5px;
|
|
||||||
border-radius: 5px;
|
|
||||||
border: 1px solid #ddd;
|
|
||||||
background-color: #ffffff;
|
|
||||||
max-height: 30px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#websocketInput {
|
|
||||||
width: 200px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#chunkSelector:focus,
|
|
||||||
#websocketInput:focus {
|
|
||||||
outline: none;
|
|
||||||
border-color: #007bff;
|
|
||||||
}
|
|
||||||
|
|
||||||
label {
|
|
||||||
font-size: 14px;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Speaker-labeled transcript area */
|
|
||||||
#linesTranscript {
|
|
||||||
margin: 20px auto;
|
|
||||||
max-width: 700px;
|
|
||||||
text-align: left;
|
|
||||||
font-size: 16px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#linesTranscript p {
|
|
||||||
margin: 0px 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
#linesTranscript strong {
|
|
||||||
color: #333;
|
|
||||||
}
|
|
||||||
|
|
||||||
#speaker {
|
|
||||||
border: 1px solid rgb(229, 229, 229);
|
|
||||||
border-radius: 100px;
|
|
||||||
padding: 2px 10px;
|
|
||||||
font-size: 14px;
|
|
||||||
margin-bottom: 0px;
|
|
||||||
}
|
|
||||||
.label_diarization {
|
|
||||||
background-color: #ffffff66;
|
|
||||||
border-radius: 8px 8px 8px 8px;
|
|
||||||
padding: 2px 10px;
|
|
||||||
margin-left: 10px;
|
|
||||||
display: inline-block;
|
|
||||||
white-space: nowrap;
|
|
||||||
font-size: 14px;
|
|
||||||
margin-bottom: 0px;
|
|
||||||
color: rgb(134, 134, 134)
|
|
||||||
}
|
|
||||||
|
|
||||||
.label_transcription {
|
|
||||||
background-color: #ffffff66;
|
|
||||||
border-radius: 8px 8px 8px 8px;
|
|
||||||
padding: 2px 10px;
|
|
||||||
display: inline-block;
|
|
||||||
white-space: nowrap;
|
|
||||||
margin-left: 10px;
|
|
||||||
font-size: 14px;
|
|
||||||
margin-bottom: 0px;
|
|
||||||
color: #000000
|
|
||||||
}
|
|
||||||
|
|
||||||
#timeInfo {
|
|
||||||
color: #666;
|
|
||||||
margin-left: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.textcontent {
|
|
||||||
font-size: 16px;
|
|
||||||
/* margin-left: 10px; */
|
|
||||||
padding-left: 10px;
|
|
||||||
margin-bottom: 10px;
|
|
||||||
margin-top: 1px;
|
|
||||||
padding-top: 5px;
|
|
||||||
border-radius: 0px 0px 0px 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.buffer_diarization {
|
|
||||||
color: rgb(134, 134, 134);
|
|
||||||
margin-left: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.buffer_transcription {
|
|
||||||
color: #7474748c;
|
|
||||||
margin-left: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
.spinner {
|
|
||||||
display: inline-block;
|
|
||||||
width: 8px;
|
|
||||||
height: 8px;
|
|
||||||
border: 2px solid #8d8d8d5c;
|
|
||||||
border-top: 2px solid #6c6c6ce5;
|
|
||||||
border-radius: 50%;
|
|
||||||
animation: spin 0.6s linear infinite;
|
|
||||||
vertical-align: middle;
|
|
||||||
margin-bottom: 2px;
|
|
||||||
margin-right: 5px;
|
|
||||||
}
|
|
||||||
|
|
||||||
@keyframes spin {
|
|
||||||
to {
|
|
||||||
transform: rotate(360deg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
.silence {
|
|
||||||
color: #666;
|
|
||||||
background-color: #f3f3f3;
|
|
||||||
font-size: 13px;
|
|
||||||
border-radius: 30px;
|
|
||||||
padding: 2px 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.loading {
|
|
||||||
color: #666;
|
|
||||||
background-color: #ff4d4d0f;
|
|
||||||
border-radius: 8px 8px 8px 0px;
|
|
||||||
padding: 2px 10px;
|
|
||||||
font-size: 14px;
|
|
||||||
margin-bottom: 0px;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body>
|
<body>
|
||||||
|
<div class="settings-container">
|
||||||
<div class="settings-container">
|
<button id="recordButton">
|
||||||
<button id="recordButton">
|
<div class="shape-container">
|
||||||
<div class="shape-container">
|
<div class="shape"></div>
|
||||||
<div class="shape"></div>
|
</div>
|
||||||
</div>
|
<div class="recording-info">
|
||||||
<div class="recording-info">
|
<div class="wave-container">
|
||||||
<div class="wave-container">
|
<canvas id="waveCanvas"></canvas>
|
||||||
<canvas id="waveCanvas"></canvas>
|
|
||||||
</div>
|
|
||||||
<div class="timer">00:00</div>
|
|
||||||
</div>
|
|
||||||
</button>
|
|
||||||
<div class="settings">
|
|
||||||
<div>
|
|
||||||
<label for="chunkSelector">Chunk size (ms):</label>
|
|
||||||
<select id="chunkSelector">
|
|
||||||
<option value="500">500 ms</option>
|
|
||||||
<option value="1000" selected>1000 ms</option>
|
|
||||||
<option value="2000">2000 ms</option>
|
|
||||||
<option value="3000">3000 ms</option>
|
|
||||||
<option value="4000">4000 ms</option>
|
|
||||||
<option value="5000">5000 ms</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label for="websocketInput">WebSocket URL:</label>
|
|
||||||
<input id="websocketInput" type="text" />
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
<div class="timer">00:00</div>
|
||||||
|
</div>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div class="settings">
|
||||||
|
<div class="field">
|
||||||
|
<label for="websocketInput">WebSocket URL</label>
|
||||||
|
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<p id="status"></p>
|
<div class="theme-selector-container">
|
||||||
|
<div class="segmented" role="radiogroup" aria-label="Theme selector">
|
||||||
|
<input type="radio" id="theme-system" name="theme" value="system" />
|
||||||
|
<label for="theme-system" title="System">
|
||||||
|
<img src="/web/src/system_mode.svg" alt="" />
|
||||||
|
<span>System</span>
|
||||||
|
</label>
|
||||||
|
|
||||||
<!-- Speaker-labeled transcript -->
|
<input type="radio" id="theme-light" name="theme" value="light" />
|
||||||
<div id="linesTranscript"></div>
|
<label for="theme-light" title="Light">
|
||||||
|
<img src="/web/src/light_mode.svg" alt="" />
|
||||||
|
<span>Light</span>
|
||||||
|
</label>
|
||||||
|
|
||||||
<script>
|
<input type="radio" id="theme-dark" name="theme" value="dark" />
|
||||||
let isRecording = false;
|
<label for="theme-dark" title="Dark">
|
||||||
let websocket = null;
|
<img src="/web/src/dark_mode.svg" alt="" />
|
||||||
let recorder = null;
|
<span>Dark</span>
|
||||||
let chunkDuration = 1000;
|
</label>
|
||||||
let websocketUrl = "ws://localhost:8000/asr";
|
</div>
|
||||||
let userClosing = false;
|
</div>
|
||||||
let startTime = null;
|
|
||||||
let timerInterval = null;
|
|
||||||
let audioContext = null;
|
|
||||||
let analyser = null;
|
|
||||||
let microphone = null;
|
|
||||||
let waveCanvas = document.getElementById("waveCanvas");
|
|
||||||
let waveCtx = waveCanvas.getContext("2d");
|
|
||||||
let animationFrame = null;
|
|
||||||
let waitingForStop = false;
|
|
||||||
let lastReceivedData = null;
|
|
||||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
|
||||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
|
||||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
|
||||||
|
|
||||||
const statusText = document.getElementById("status");
|
<p id="status"></p>
|
||||||
const recordButton = document.getElementById("recordButton");
|
|
||||||
const chunkSelector = document.getElementById("chunkSelector");
|
|
||||||
const websocketInput = document.getElementById("websocketInput");
|
|
||||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
|
||||||
const timerElement = document.querySelector(".timer");
|
|
||||||
|
|
||||||
const host = window.location.hostname || "localhost";
|
<div id="linesTranscript"></div>
|
||||||
const port = window.location.port;
|
|
||||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
|
||||||
const defaultWebSocketUrl = `${protocol}://${host}:${port}/asr`;
|
|
||||||
websocketInput.value = defaultWebSocketUrl;
|
|
||||||
websocketUrl = defaultWebSocketUrl;
|
|
||||||
|
|
||||||
chunkSelector.addEventListener("change", () => {
|
<script src="/web/live_transcription.js"></script>
|
||||||
chunkDuration = parseInt(chunkSelector.value);
|
|
||||||
});
|
|
||||||
|
|
||||||
websocketInput.addEventListener("change", () => {
|
|
||||||
const urlValue = websocketInput.value.trim();
|
|
||||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
|
||||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
websocketUrl = urlValue;
|
|
||||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
|
||||||
});
|
|
||||||
|
|
||||||
function setupWebSocket() {
|
|
||||||
return new Promise((resolve, reject) => {
|
|
||||||
try {
|
|
||||||
websocket = new WebSocket(websocketUrl);
|
|
||||||
} catch (error) {
|
|
||||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
|
||||||
reject(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
websocket.onopen = () => {
|
|
||||||
statusText.textContent = "Connected to server.";
|
|
||||||
resolve();
|
|
||||||
};
|
|
||||||
|
|
||||||
websocket.onclose = () => {
|
|
||||||
if (userClosing) {
|
|
||||||
if (waitingForStop) {
|
|
||||||
statusText.textContent = "Processing finalized or connection closed.";
|
|
||||||
if (lastReceivedData) {
|
|
||||||
renderLinesWithBuffer(
|
|
||||||
lastReceivedData.lines || [],
|
|
||||||
lastReceivedData.buffer_diarization || "",
|
|
||||||
lastReceivedData.buffer_transcription || "",
|
|
||||||
0, 0, true // isFinalizing = true
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If ready_to_stop was received, statusText is already "Finished processing..."
|
|
||||||
// and waitingForStop is false.
|
|
||||||
} else {
|
|
||||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
|
||||||
if (isRecording) {
|
|
||||||
stopRecording();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
isRecording = false;
|
|
||||||
waitingForStop = false;
|
|
||||||
userClosing = false;
|
|
||||||
lastReceivedData = null;
|
|
||||||
websocket = null;
|
|
||||||
updateUI();
|
|
||||||
};
|
|
||||||
|
|
||||||
websocket.onerror = () => {
|
|
||||||
statusText.textContent = "Error connecting to WebSocket.";
|
|
||||||
reject(new Error("Error connecting to WebSocket"));
|
|
||||||
};
|
|
||||||
|
|
||||||
// Handle messages from server
|
|
||||||
websocket.onmessage = (event) => {
|
|
||||||
const data = JSON.parse(event.data);
|
|
||||||
|
|
||||||
// Check for status messages
|
|
||||||
if (data.type === "ready_to_stop") {
|
|
||||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
|
||||||
waitingForStop = false;
|
|
||||||
|
|
||||||
if (lastReceivedData) {
|
|
||||||
renderLinesWithBuffer(
|
|
||||||
lastReceivedData.lines || [],
|
|
||||||
lastReceivedData.buffer_diarization || "",
|
|
||||||
lastReceivedData.buffer_transcription || "",
|
|
||||||
0, // No more lag
|
|
||||||
0, // No more lag
|
|
||||||
true // isFinalizing = true
|
|
||||||
);
|
|
||||||
}
|
|
||||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
|
||||||
recordButton.disabled = false;
|
|
||||||
|
|
||||||
if (websocket) {
|
|
||||||
websocket.close(); // will trigger onclose
|
|
||||||
// websocket = null; // onclose handle setting websocket to null
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
lastReceivedData = data;
|
|
||||||
|
|
||||||
// Handle normal transcription updates
|
|
||||||
const {
|
|
||||||
lines = [],
|
|
||||||
buffer_transcription = "",
|
|
||||||
buffer_diarization = "",
|
|
||||||
remaining_time_transcription = 0,
|
|
||||||
remaining_time_diarization = 0,
|
|
||||||
status = "active_transcription"
|
|
||||||
} = data;
|
|
||||||
|
|
||||||
renderLinesWithBuffer(
|
|
||||||
lines,
|
|
||||||
buffer_diarization,
|
|
||||||
buffer_transcription,
|
|
||||||
remaining_time_diarization,
|
|
||||||
remaining_time_transcription,
|
|
||||||
false,
|
|
||||||
status
|
|
||||||
);
|
|
||||||
};
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
|
|
||||||
if (current_status === "no_audio_detected") {
|
|
||||||
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const linesHtml = lines.map((item, idx) => {
|
|
||||||
let timeInfo = "";
|
|
||||||
if (item.beg !== undefined && item.end !== undefined) {
|
|
||||||
timeInfo = ` ${item.beg} - ${item.end}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
let speakerLabel = "";
|
|
||||||
if (item.speaker === -2) {
|
|
||||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
|
||||||
} else if (item.speaker == 0 && !isFinalizing) {
|
|
||||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
|
|
||||||
} else if (item.speaker == -1) {
|
|
||||||
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
|
|
||||||
} else if (item.speaker !== -1 && item.speaker !== 0) {
|
|
||||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
let currentLineText = item.text || "";
|
|
||||||
|
|
||||||
if (idx === lines.length - 1) {
|
|
||||||
if (!isFinalizing) {
|
|
||||||
if (remaining_time_transcription > 0) {
|
|
||||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
|
|
||||||
}
|
|
||||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
|
||||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buffer_diarization) {
|
|
||||||
if (isFinalizing) {
|
|
||||||
currentLineText += (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
|
||||||
} else {
|
|
||||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (buffer_transcription) {
|
|
||||||
if (isFinalizing) {
|
|
||||||
currentLineText += (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") + buffer_transcription.trim();
|
|
||||||
} else {
|
|
||||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
|
||||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
|
||||||
: `<p>${speakerLabel}<br/></p>`;
|
|
||||||
}).join("");
|
|
||||||
|
|
||||||
linesTranscriptDiv.innerHTML = linesHtml;
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateTimer() {
|
|
||||||
if (!startTime) return;
|
|
||||||
|
|
||||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
|
||||||
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
|
||||||
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
|
||||||
timerElement.textContent = `${minutes}:${seconds}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
function drawWaveform() {
|
|
||||||
if (!analyser) return;
|
|
||||||
|
|
||||||
const bufferLength = analyser.frequencyBinCount;
|
|
||||||
const dataArray = new Uint8Array(bufferLength);
|
|
||||||
analyser.getByteTimeDomainData(dataArray);
|
|
||||||
|
|
||||||
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
|
|
||||||
waveCtx.lineWidth = 1;
|
|
||||||
waveCtx.strokeStyle = 'rgb(0, 0, 0)';
|
|
||||||
waveCtx.beginPath();
|
|
||||||
|
|
||||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
|
||||||
let x = 0;
|
|
||||||
|
|
||||||
for (let i = 0; i < bufferLength; i++) {
|
|
||||||
const v = dataArray[i] / 128.0;
|
|
||||||
const y = v * (waveCanvas.height / (window.devicePixelRatio || 1)) / 2;
|
|
||||||
|
|
||||||
if (i === 0) {
|
|
||||||
waveCtx.moveTo(x, y);
|
|
||||||
} else {
|
|
||||||
waveCtx.lineTo(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
x += sliceWidth;
|
|
||||||
}
|
|
||||||
|
|
||||||
waveCtx.lineTo(waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1) / 2);
|
|
||||||
waveCtx.stroke();
|
|
||||||
|
|
||||||
animationFrame = requestAnimationFrame(drawWaveform);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function startRecording() {
|
|
||||||
try {
|
|
||||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
|
||||||
|
|
||||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
|
||||||
analyser = audioContext.createAnalyser();
|
|
||||||
analyser.fftSize = 256;
|
|
||||||
microphone = audioContext.createMediaStreamSource(stream);
|
|
||||||
microphone.connect(analyser);
|
|
||||||
|
|
||||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
|
||||||
recorder.ondataavailable = (e) => {
|
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
|
||||||
websocket.send(e.data);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
recorder.start(chunkDuration);
|
|
||||||
|
|
||||||
startTime = Date.now();
|
|
||||||
timerInterval = setInterval(updateTimer, 1000);
|
|
||||||
drawWaveform();
|
|
||||||
|
|
||||||
isRecording = true;
|
|
||||||
updateUI();
|
|
||||||
} catch (err) {
|
|
||||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
|
||||||
console.error(err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function stopRecording() {
|
|
||||||
userClosing = true;
|
|
||||||
waitingForStop = true;
|
|
||||||
|
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
|
||||||
// Send empty audio buffer as stop signal
|
|
||||||
const emptyBlob = new Blob([], { type: 'audio/webm' });
|
|
||||||
websocket.send(emptyBlob);
|
|
||||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
|
||||||
}
|
|
||||||
|
|
||||||
if (recorder) {
|
|
||||||
recorder.stop();
|
|
||||||
recorder = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (microphone) {
|
|
||||||
microphone.disconnect();
|
|
||||||
microphone = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (analyser) {
|
|
||||||
analyser = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (audioContext && audioContext.state !== 'closed') {
|
|
||||||
try {
|
|
||||||
audioContext.close();
|
|
||||||
} catch (e) {
|
|
||||||
console.warn("Could not close audio context:", e);
|
|
||||||
}
|
|
||||||
audioContext = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (animationFrame) {
|
|
||||||
cancelAnimationFrame(animationFrame);
|
|
||||||
animationFrame = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (timerInterval) {
|
|
||||||
clearInterval(timerInterval);
|
|
||||||
timerInterval = null;
|
|
||||||
}
|
|
||||||
timerElement.textContent = "00:00";
|
|
||||||
startTime = null;
|
|
||||||
|
|
||||||
|
|
||||||
isRecording = false;
|
|
||||||
updateUI();
|
|
||||||
}
|
|
||||||
|
|
||||||
async function toggleRecording() {
|
|
||||||
if (!isRecording) {
|
|
||||||
if (waitingForStop) {
|
|
||||||
console.log("Waiting for stop, early return");
|
|
||||||
return; // Early return, UI is already updated
|
|
||||||
}
|
|
||||||
console.log("Connecting to WebSocket");
|
|
||||||
try {
|
|
||||||
// If we have an active WebSocket that's still processing, just restart audio capture
|
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
|
||||||
await startRecording();
|
|
||||||
} else {
|
|
||||||
// If no active WebSocket or it's closed, create new one
|
|
||||||
await setupWebSocket();
|
|
||||||
await startRecording();
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
|
||||||
console.error(err);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
console.log("Stopping recording");
|
|
||||||
stopRecording();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateUI() {
|
|
||||||
recordButton.classList.toggle("recording", isRecording);
|
|
||||||
recordButton.disabled = waitingForStop;
|
|
||||||
|
|
||||||
if (waitingForStop) {
|
|
||||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
|
||||||
statusText.textContent = "Please wait for processing to complete...";
|
|
||||||
}
|
|
||||||
} else if (isRecording) {
|
|
||||||
statusText.textContent = "Recording...";
|
|
||||||
} else {
|
|
||||||
if (statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
|
||||||
statusText.textContent !== "Processing finalized or connection closed.") {
|
|
||||||
statusText.textContent = "Click to start transcription";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!waitingForStop) {
|
|
||||||
recordButton.disabled = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
recordButton.addEventListener("click", toggleRecording);
|
|
||||||
</script>
|
|
||||||
</body>
|
</body>
|
||||||
|
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
513
whisperlivekit/web/live_transcription.js
Normal file
513
whisperlivekit/web/live_transcription.js
Normal file
@@ -0,0 +1,513 @@
|
|||||||
|
/* Theme, WebSocket, recording, rendering logic extracted from inline script and adapted for segmented theme control and WS caption */
|
||||||
|
|
||||||
|
let isRecording = false;
|
||||||
|
let websocket = null;
|
||||||
|
let recorder = null;
|
||||||
|
let chunkDuration = 100;
|
||||||
|
let websocketUrl = "ws://localhost:8000/asr";
|
||||||
|
let userClosing = false;
|
||||||
|
let wakeLock = null;
|
||||||
|
let startTime = null;
|
||||||
|
let timerInterval = null;
|
||||||
|
let audioContext = null;
|
||||||
|
let analyser = null;
|
||||||
|
let microphone = null;
|
||||||
|
let waveCanvas = document.getElementById("waveCanvas");
|
||||||
|
let waveCtx = waveCanvas.getContext("2d");
|
||||||
|
let animationFrame = null;
|
||||||
|
let waitingForStop = false;
|
||||||
|
let lastReceivedData = null;
|
||||||
|
let lastSignature = null;
|
||||||
|
|
||||||
|
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||||
|
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||||
|
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||||
|
|
||||||
|
const statusText = document.getElementById("status");
|
||||||
|
const recordButton = document.getElementById("recordButton");
|
||||||
|
const chunkSelector = document.getElementById("chunkSelector");
|
||||||
|
const websocketInput = document.getElementById("websocketInput");
|
||||||
|
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
|
||||||
|
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||||
|
const timerElement = document.querySelector(".timer");
|
||||||
|
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||||
|
|
||||||
|
function getWaveStroke() {
|
||||||
|
const styles = getComputedStyle(document.documentElement);
|
||||||
|
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||||
|
return v || "#000";
|
||||||
|
}
|
||||||
|
|
||||||
|
let waveStroke = getWaveStroke();
|
||||||
|
function updateWaveStroke() {
|
||||||
|
waveStroke = getWaveStroke();
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyTheme(pref) {
|
||||||
|
if (pref === "light") {
|
||||||
|
document.documentElement.setAttribute("data-theme", "light");
|
||||||
|
} else if (pref === "dark") {
|
||||||
|
document.documentElement.setAttribute("data-theme", "dark");
|
||||||
|
} else {
|
||||||
|
document.documentElement.removeAttribute("data-theme");
|
||||||
|
}
|
||||||
|
updateWaveStroke();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persisted theme preference
|
||||||
|
const savedThemePref = localStorage.getItem("themePreference") || "system";
|
||||||
|
applyTheme(savedThemePref);
|
||||||
|
if (themeRadios.length) {
|
||||||
|
themeRadios.forEach((r) => {
|
||||||
|
r.checked = r.value === savedThemePref;
|
||||||
|
r.addEventListener("change", () => {
|
||||||
|
if (r.checked) {
|
||||||
|
localStorage.setItem("themePreference", r.value);
|
||||||
|
applyTheme(r.value);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// React to OS theme changes when in "system" mode
|
||||||
|
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
|
||||||
|
const handleOsThemeChange = () => {
|
||||||
|
const pref = localStorage.getItem("themePreference") || "system";
|
||||||
|
if (pref === "system") updateWaveStroke();
|
||||||
|
};
|
||||||
|
if (darkMq && darkMq.addEventListener) {
|
||||||
|
darkMq.addEventListener("change", handleOsThemeChange);
|
||||||
|
} else if (darkMq && darkMq.addListener) {
|
||||||
|
// deprecated, but included for Safari compatibility
|
||||||
|
darkMq.addListener(handleOsThemeChange);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers
|
||||||
|
function fmt1(x) {
|
||||||
|
const n = Number(x);
|
||||||
|
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default WebSocket URL computation
|
||||||
|
const host = window.location.hostname || "localhost";
|
||||||
|
const port = window.location.port;
|
||||||
|
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||||
|
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
|
||||||
|
|
||||||
|
// Populate default caption and input
|
||||||
|
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
|
||||||
|
websocketInput.value = defaultWebSocketUrl;
|
||||||
|
websocketUrl = defaultWebSocketUrl;
|
||||||
|
|
||||||
|
// Optional chunk selector (guard for presence)
|
||||||
|
if (chunkSelector) {
|
||||||
|
chunkSelector.addEventListener("change", () => {
|
||||||
|
chunkDuration = parseInt(chunkSelector.value);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket input change handling
|
||||||
|
websocketInput.addEventListener("change", () => {
|
||||||
|
const urlValue = websocketInput.value.trim();
|
||||||
|
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||||
|
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
websocketUrl = urlValue;
|
||||||
|
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||||
|
});
|
||||||
|
|
||||||
|
function setupWebSocket() {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
try {
|
||||||
|
websocket = new WebSocket(websocketUrl);
|
||||||
|
} catch (error) {
|
||||||
|
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||||
|
reject(error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
websocket.onopen = () => {
|
||||||
|
statusText.textContent = "Connected to server.";
|
||||||
|
resolve();
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onclose = () => {
|
||||||
|
if (userClosing) {
|
||||||
|
if (waitingForStop) {
|
||||||
|
statusText.textContent = "Processing finalized or connection closed.";
|
||||||
|
if (lastReceivedData) {
|
||||||
|
renderLinesWithBuffer(
|
||||||
|
lastReceivedData.lines || [],
|
||||||
|
lastReceivedData.buffer_diarization || "",
|
||||||
|
lastReceivedData.buffer_transcription || "",
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||||
|
if (isRecording) {
|
||||||
|
stopRecording();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isRecording = false;
|
||||||
|
waitingForStop = false;
|
||||||
|
userClosing = false;
|
||||||
|
lastReceivedData = null;
|
||||||
|
websocket = null;
|
||||||
|
updateUI();
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onerror = () => {
|
||||||
|
statusText.textContent = "Error connecting to WebSocket.";
|
||||||
|
reject(new Error("Error connecting to WebSocket"));
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onmessage = (event) => {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
|
||||||
|
if (data.type === "ready_to_stop") {
|
||||||
|
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||||
|
waitingForStop = false;
|
||||||
|
|
||||||
|
if (lastReceivedData) {
|
||||||
|
renderLinesWithBuffer(
|
||||||
|
lastReceivedData.lines || [],
|
||||||
|
lastReceivedData.buffer_diarization || "",
|
||||||
|
lastReceivedData.buffer_transcription || "",
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
}
|
||||||
|
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||||
|
recordButton.disabled = false;
|
||||||
|
|
||||||
|
if (websocket) {
|
||||||
|
websocket.close();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
lastReceivedData = data;
|
||||||
|
|
||||||
|
const {
|
||||||
|
lines = [],
|
||||||
|
buffer_transcription = "",
|
||||||
|
buffer_diarization = "",
|
||||||
|
remaining_time_transcription = 0,
|
||||||
|
remaining_time_diarization = 0,
|
||||||
|
status = "active_transcription",
|
||||||
|
} = data;
|
||||||
|
|
||||||
|
renderLinesWithBuffer(
|
||||||
|
lines,
|
||||||
|
buffer_diarization,
|
||||||
|
buffer_transcription,
|
||||||
|
remaining_time_diarization,
|
||||||
|
remaining_time_transcription,
|
||||||
|
false,
|
||||||
|
status
|
||||||
|
);
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderLinesWithBuffer(
|
||||||
|
lines,
|
||||||
|
buffer_diarization,
|
||||||
|
buffer_transcription,
|
||||||
|
remaining_time_diarization,
|
||||||
|
remaining_time_transcription,
|
||||||
|
isFinalizing = false,
|
||||||
|
current_status = "active_transcription"
|
||||||
|
) {
|
||||||
|
if (current_status === "no_audio_detected") {
|
||||||
|
linesTranscriptDiv.innerHTML =
|
||||||
|
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
|
||||||
|
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||||
|
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||||
|
const signature = JSON.stringify({
|
||||||
|
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, beg: it.beg, end: it.end })),
|
||||||
|
buffer_transcription: buffer_transcription || "",
|
||||||
|
buffer_diarization: buffer_diarization || "",
|
||||||
|
status: current_status,
|
||||||
|
showLoading,
|
||||||
|
showTransLag,
|
||||||
|
showDiaLag,
|
||||||
|
isFinalizing: !!isFinalizing,
|
||||||
|
});
|
||||||
|
if (lastSignature === signature) {
|
||||||
|
const t = document.querySelector(".lag-transcription-value");
|
||||||
|
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||||
|
const d = document.querySelector(".lag-diarization-value");
|
||||||
|
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||||
|
const ld = document.querySelector(".loading-diarization-value");
|
||||||
|
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lastSignature = signature;
|
||||||
|
|
||||||
|
const linesHtml = (lines || [])
|
||||||
|
.map((item, idx) => {
|
||||||
|
let timeInfo = "";
|
||||||
|
if (item.beg !== undefined && item.end !== undefined) {
|
||||||
|
timeInfo = ` ${item.beg} - ${item.end}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
let speakerLabel = "";
|
||||||
|
if (item.speaker === -2) {
|
||||||
|
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
|
} else if (item.speaker == 0 && !isFinalizing) {
|
||||||
|
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||||
|
remaining_time_diarization
|
||||||
|
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||||
|
} else if (item.speaker !== 0) {
|
||||||
|
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
let currentLineText = item.text || "";
|
||||||
|
|
||||||
|
if (idx === lines.length - 1) {
|
||||||
|
if (!isFinalizing && item.speaker !== -2) {
|
||||||
|
if (remaining_time_transcription > 0) {
|
||||||
|
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||||
|
remaining_time_transcription
|
||||||
|
)}</span>s</span></span>`;
|
||||||
|
}
|
||||||
|
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||||
|
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||||
|
remaining_time_diarization
|
||||||
|
)}</span>s</span></span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (buffer_diarization) {
|
||||||
|
if (isFinalizing) {
|
||||||
|
currentLineText +=
|
||||||
|
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||||
|
} else {
|
||||||
|
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (buffer_transcription) {
|
||||||
|
if (isFinalizing) {
|
||||||
|
currentLineText +=
|
||||||
|
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
|
||||||
|
buffer_transcription.trim();
|
||||||
|
} else {
|
||||||
|
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||||
|
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||||
|
: `<p>${speakerLabel}<br/></p>`;
|
||||||
|
})
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
linesTranscriptDiv.innerHTML = linesHtml;
|
||||||
|
window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateTimer() {
|
||||||
|
if (!startTime) return;
|
||||||
|
|
||||||
|
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||||
|
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
||||||
|
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
||||||
|
timerElement.textContent = `${minutes}:${seconds}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function drawWaveform() {
|
||||||
|
if (!analyser) return;
|
||||||
|
|
||||||
|
const bufferLength = analyser.frequencyBinCount;
|
||||||
|
const dataArray = new Uint8Array(bufferLength);
|
||||||
|
analyser.getByteTimeDomainData(dataArray);
|
||||||
|
|
||||||
|
waveCtx.clearRect(
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||||
|
waveCanvas.height / (window.devicePixelRatio || 1)
|
||||||
|
);
|
||||||
|
waveCtx.lineWidth = 1;
|
||||||
|
waveCtx.strokeStyle = waveStroke;
|
||||||
|
waveCtx.beginPath();
|
||||||
|
|
||||||
|
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||||
|
let x = 0;
|
||||||
|
|
||||||
|
for (let i = 0; i < bufferLength; i++) {
|
||||||
|
const v = dataArray[i] / 128.0;
|
||||||
|
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
|
||||||
|
|
||||||
|
if (i === 0) {
|
||||||
|
waveCtx.moveTo(x, y);
|
||||||
|
} else {
|
||||||
|
waveCtx.lineTo(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
x += sliceWidth;
|
||||||
|
}
|
||||||
|
|
||||||
|
waveCtx.lineTo(
|
||||||
|
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||||
|
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
|
||||||
|
);
|
||||||
|
waveCtx.stroke();
|
||||||
|
|
||||||
|
animationFrame = requestAnimationFrame(drawWaveform);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function startRecording() {
|
||||||
|
try {
|
||||||
|
try {
|
||||||
|
wakeLock = await navigator.wakeLock.request("screen");
|
||||||
|
} catch (err) {
|
||||||
|
console.log("Error acquiring wake lock.");
|
||||||
|
}
|
||||||
|
|
||||||
|
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||||
|
|
||||||
|
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||||
|
analyser = audioContext.createAnalyser();
|
||||||
|
analyser.fftSize = 256;
|
||||||
|
microphone = audioContext.createMediaStreamSource(stream);
|
||||||
|
microphone.connect(analyser);
|
||||||
|
|
||||||
|
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||||
|
recorder.ondataavailable = (e) => {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
websocket.send(e.data);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
recorder.start(chunkDuration);
|
||||||
|
|
||||||
|
startTime = Date.now();
|
||||||
|
timerInterval = setInterval(updateTimer, 1000);
|
||||||
|
drawWaveform();
|
||||||
|
|
||||||
|
isRecording = true;
|
||||||
|
updateUI();
|
||||||
|
} catch (err) {
|
||||||
|
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||||
|
console.error(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function stopRecording() {
|
||||||
|
if (wakeLock) {
|
||||||
|
try {
|
||||||
|
await wakeLock.release();
|
||||||
|
} catch (e) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
wakeLock = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
userClosing = true;
|
||||||
|
waitingForStop = true;
|
||||||
|
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
const emptyBlob = new Blob([], { type: "audio/webm" });
|
||||||
|
websocket.send(emptyBlob);
|
||||||
|
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recorder) {
|
||||||
|
recorder.stop();
|
||||||
|
recorder = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (microphone) {
|
||||||
|
microphone.disconnect();
|
||||||
|
microphone = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (analyser) {
|
||||||
|
analyser = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (audioContext && audioContext.state !== "closed") {
|
||||||
|
try {
|
||||||
|
await audioContext.close();
|
||||||
|
} catch (e) {
|
||||||
|
console.warn("Could not close audio context:", e);
|
||||||
|
}
|
||||||
|
audioContext = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (animationFrame) {
|
||||||
|
cancelAnimationFrame(animationFrame);
|
||||||
|
animationFrame = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (timerInterval) {
|
||||||
|
clearInterval(timerInterval);
|
||||||
|
timerInterval = null;
|
||||||
|
}
|
||||||
|
timerElement.textContent = "00:00";
|
||||||
|
startTime = null;
|
||||||
|
|
||||||
|
isRecording = false;
|
||||||
|
updateUI();
|
||||||
|
}
|
||||||
|
|
||||||
|
async function toggleRecording() {
|
||||||
|
if (!isRecording) {
|
||||||
|
if (waitingForStop) {
|
||||||
|
console.log("Waiting for stop, early return");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
console.log("Connecting to WebSocket");
|
||||||
|
try {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
await startRecording();
|
||||||
|
} else {
|
||||||
|
await setupWebSocket();
|
||||||
|
await startRecording();
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||||
|
console.error(err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.log("Stopping recording");
|
||||||
|
stopRecording();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateUI() {
|
||||||
|
recordButton.classList.toggle("recording", isRecording);
|
||||||
|
recordButton.disabled = waitingForStop;
|
||||||
|
|
||||||
|
if (waitingForStop) {
|
||||||
|
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||||
|
statusText.textContent = "Please wait for processing to complete...";
|
||||||
|
}
|
||||||
|
} else if (isRecording) {
|
||||||
|
statusText.textContent = "Recording...";
|
||||||
|
} else {
|
||||||
|
if (
|
||||||
|
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||||
|
statusText.textContent !== "Processing finalized or connection closed."
|
||||||
|
) {
|
||||||
|
statusText.textContent = "Click to start transcription";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!waitingForStop) {
|
||||||
|
recordButton.disabled = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
recordButton.addEventListener("click", toggleRecording);
|
||||||
1
whisperlivekit/web/src/dark_mode.svg
Normal file
1
whisperlivekit/web/src/dark_mode.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>
|
||||||
|
After Width: | Height: | Size: 493 B |
1
whisperlivekit/web/src/light_mode.svg
Normal file
1
whisperlivekit/web/src/light_mode.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>
|
||||||
|
After Width: | Height: | Size: 1.2 KiB |
1
whisperlivekit/web/src/system_mode.svg
Normal file
1
whisperlivekit/web/src/system_mode.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>
|
||||||
|
After Width: | Height: | Size: 1.4 KiB |
@@ -10,4 +10,24 @@ def get_web_interface_html():
|
|||||||
return f.read()
|
return f.read()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading web interface HTML: {e}")
|
logger.error(f"Error loading web interface HTML: {e}")
|
||||||
return "<html><body><h1>Error loading interface</h1></body></html>"
|
return "<html><body><h1>Error loading interface</h1></body></html>"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
import uvicorn
|
||||||
|
from starlette.staticfiles import StaticFiles
|
||||||
|
import pathlib
|
||||||
|
import whisperlivekit.web as webpkg
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
web_dir = pathlib.Path(webpkg.__file__).parent
|
||||||
|
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def get():
|
||||||
|
return HTMLResponse(get_web_interface_html())
|
||||||
|
|
||||||
|
uvicorn.run(app=app)
|
||||||
@@ -3,43 +3,10 @@ import logging
|
|||||||
import io
|
import io
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import math
|
import math
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
except ImportError:
|
|
||||||
torch = None
|
|
||||||
from typing import List
|
from typing import List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from whisperlivekit.timed_objects import ASRToken
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS = ImportError(
|
|
||||||
"""SimulStreaming dependencies are not available.
|
|
||||||
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]"
|
|
||||||
""")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
|
||||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("⚠️ SimulStreaming dependencies not available. Attempting to download them.")
|
|
||||||
try:
|
|
||||||
from whisperlivekit import download_simulstreaming_backend
|
|
||||||
download_simulstreaming_backend()
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
|
||||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
logger.info("SimulStreaming dependencies downloaded successfully.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to download or import SimulStreaming dependencies: {e}")
|
|
||||||
SIMULSTREAMING_AVAILABLE = False
|
|
||||||
AlignAttConfig = None
|
|
||||||
PaddedAlignAttWhisper = None
|
|
||||||
DEC_PAD = None
|
|
||||||
tokenizer = None
|
|
||||||
|
|
||||||
class ASRBase:
|
class ASRBase:
|
||||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||||
# "" for faster-whisper because it emits the spaces when needed)
|
# "" for faster-whisper because it emits the spaces when needed)
|
||||||
@@ -320,182 +287,4 @@ class OpenaiApiASR(ASRBase):
|
|||||||
self.use_vad_opt = True
|
self.use_vad_opt = True
|
||||||
|
|
||||||
def set_translate_task(self):
|
def set_translate_task(self):
|
||||||
self.task = "translate"
|
self.task = "translate"
|
||||||
|
|
||||||
|
|
||||||
class SimulStreamingASR(ASRBase):
|
|
||||||
"""SimulStreaming backend with AlignAtt policy."""
|
|
||||||
sep = ""
|
|
||||||
|
|
||||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
|
||||||
with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
|
|
||||||
print("*"*80 + f.read() + "*"*80)
|
|
||||||
self.logfile = logfile
|
|
||||||
self.transcribe_kargs = {}
|
|
||||||
self.original_language = None if lan == "auto" else lan
|
|
||||||
|
|
||||||
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
|
||||||
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
|
||||||
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
|
|
||||||
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
|
||||||
self.segment_length = kwargs.get('segment_length', 0.5)
|
|
||||||
self.beams = kwargs.get('beams', 1)
|
|
||||||
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
|
||||||
self.task = kwargs.get('task', 'transcribe')
|
|
||||||
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
|
||||||
self.never_fire = kwargs.get('never_fire', False)
|
|
||||||
self.init_prompt = kwargs.get('init_prompt', None)
|
|
||||||
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
|
||||||
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
|
||||||
|
|
||||||
if model_dir is not None:
|
|
||||||
self.model_path = model_dir
|
|
||||||
elif modelsize is not None: #For the moment the .en.pt models do not work!
|
|
||||||
model_mapping = {
|
|
||||||
'tiny': './tiny.pt',
|
|
||||||
'base': './base.pt',
|
|
||||||
'small': './small.pt',
|
|
||||||
'medium': './medium.pt',
|
|
||||||
'medium.en': './medium.en.pt',
|
|
||||||
'large-v1': './large-v1.pt',
|
|
||||||
'base.en': './base.en.pt',
|
|
||||||
'small.en': './small.en.pt',
|
|
||||||
'tiny.en': './tiny.en.pt',
|
|
||||||
'large-v2': './large-v2.pt',
|
|
||||||
'large-v3': './large-v3.pt',
|
|
||||||
'large': './large-v3.pt'
|
|
||||||
}
|
|
||||||
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
|
||||||
|
|
||||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
|
||||||
|
|
||||||
# Set up tokenizer for translation if needed
|
|
||||||
if self.task == "translate":
|
|
||||||
self.set_translate_task()
|
|
||||||
|
|
||||||
def load_model(self, modelsize, cache_dir, model_dir):
|
|
||||||
try:
|
|
||||||
cfg = AlignAttConfig(
|
|
||||||
model_path=self.model_path,
|
|
||||||
segment_length=self.segment_length,
|
|
||||||
frame_threshold=self.frame_threshold,
|
|
||||||
language=self.original_language,
|
|
||||||
audio_max_len=self.audio_max_len,
|
|
||||||
audio_min_len=self.audio_min_len,
|
|
||||||
cif_ckpt_path=self.cif_ckpt_path,
|
|
||||||
decoder_type="beam",
|
|
||||||
beam_size=self.beams,
|
|
||||||
task=self.task,
|
|
||||||
never_fire=self.never_fire,
|
|
||||||
init_prompt=self.init_prompt,
|
|
||||||
max_context_tokens=self.max_context_tokens,
|
|
||||||
static_init_prompt=self.static_init_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
|
|
||||||
model = PaddedAlignAttWhisper(cfg)
|
|
||||||
return model
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load SimulStreaming model: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def transcribe(self, audio, init_prompt=""):
|
|
||||||
"""Transcribe audio using SimulStreaming."""
|
|
||||||
try:
|
|
||||||
if isinstance(audio, np.ndarray):
|
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
|
||||||
else:
|
|
||||||
audio_tensor = audio
|
|
||||||
|
|
||||||
prompt = init_prompt if init_prompt else (self.init_prompt or "")
|
|
||||||
|
|
||||||
result = self.model.infer(audio_tensor, init_prompt=prompt)
|
|
||||||
|
|
||||||
if torch.is_tensor(result):
|
|
||||||
result = result[result < DEC_PAD]
|
|
||||||
|
|
||||||
logger.debug(f"SimulStreaming transcription result: {result}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"SimulStreaming transcription failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def ts_words(self, result) -> List[ASRToken]:
|
|
||||||
"""Convert SimulStreaming result to ASRToken list."""
|
|
||||||
tokens = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
if torch.is_tensor(result):
|
|
||||||
text = self.model.tokenizer.decode(result.cpu().numpy())
|
|
||||||
else:
|
|
||||||
text = str(result)
|
|
||||||
|
|
||||||
if not text or len(text.strip()) == 0:
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
# We dont have word-level timestamps here. 1rst approach, should be improved later.
|
|
||||||
words = text.strip().split()
|
|
||||||
if not words:
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
duration_per_word = 0.1 # this will be modified based on actual audio duration
|
|
||||||
#with the SimulStreamingOnlineProcessor
|
|
||||||
|
|
||||||
for i, word in enumerate(words):
|
|
||||||
start_time = i * duration_per_word
|
|
||||||
end_time = (i + 1) * duration_per_word
|
|
||||||
|
|
||||||
token = ASRToken(
|
|
||||||
start=start_time,
|
|
||||||
end=end_time,
|
|
||||||
text=word,
|
|
||||||
probability=1.0
|
|
||||||
)
|
|
||||||
tokens.append(token)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error converting SimulStreaming result to tokens: {e}")
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def segments_end_ts(self, result) -> List[float]:
|
|
||||||
"""Get segment end timestamps."""
|
|
||||||
if torch.is_tensor(result):
|
|
||||||
num_tokens = len(result)
|
|
||||||
return [num_tokens * 0.1] # rough estimate
|
|
||||||
return [1.0]
|
|
||||||
|
|
||||||
def use_vad(self):
|
|
||||||
"""Enable VAD - SimulStreaming has different VAD handling."""
|
|
||||||
logger.info("VAD requested for SimulStreaming - handled internally by the model")
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_translate_task(self):
|
|
||||||
"""Set up translation task."""
|
|
||||||
try:
|
|
||||||
self.model.tokenizer = tokenizer.get_tokenizer(
|
|
||||||
multilingual=True,
|
|
||||||
language=self.model.cfg.language,
|
|
||||||
num_languages=self.model.model.num_languages,
|
|
||||||
task="translate"
|
|
||||||
)
|
|
||||||
logger.info("SimulStreaming configured for translation task")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to configure SimulStreaming for translation: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def warmup(self, audio, init_prompt=""):
|
|
||||||
"""Warmup the SimulStreaming model."""
|
|
||||||
try:
|
|
||||||
if isinstance(audio, np.ndarray):
|
|
||||||
audio = torch.from_numpy(audio).float()
|
|
||||||
self.model.insert_audio(audio)
|
|
||||||
self.model.infer(True)
|
|
||||||
self.model.refresh_segment(complete=True)
|
|
||||||
logger.info("SimulStreaming model warmed up successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"SimulStreaming warmup failed: {e}")
|
|
||||||
@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# simulStreaming imports - we check if the files are here
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
|
||||||
SIMULSTREAMING_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("SimulStreaming dependencies not available for online processor.")
|
|
||||||
SIMULSTREAMING_AVAILABLE = False
|
|
||||||
OnlineProcessorInterface = None
|
|
||||||
torch = None
|
|
||||||
|
|
||||||
|
|
||||||
class HypothesisBuffer:
|
class HypothesisBuffer:
|
||||||
"""
|
"""
|
||||||
Buffer to store and process ASR hypothesis tokens.
|
Buffer to store and process ASR hypothesis tokens.
|
||||||
@@ -134,6 +122,7 @@ class OnlineASRProcessor:
|
|||||||
self.tokenize = tokenize_method
|
self.tokenize = tokenize_method
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.confidence_validation = confidence_validation
|
self.confidence_validation = confidence_validation
|
||||||
|
self.global_time_offset = 0.0
|
||||||
self.init()
|
self.init()
|
||||||
|
|
||||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||||
@@ -164,6 +153,21 @@ class OnlineASRProcessor:
|
|||||||
"""Append an audio chunk (a numpy array) to the current audio buffer."""
|
"""Append an audio chunk (a numpy array) to the current audio buffer."""
|
||||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration, offset):
|
||||||
|
"""
|
||||||
|
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||||
|
"""
|
||||||
|
# if self.transcript_buffer.buffer:
|
||||||
|
# self.committed.extend(self.transcript_buffer.buffer)
|
||||||
|
# self.transcript_buffer.buffer = []
|
||||||
|
|
||||||
|
if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence.
|
||||||
|
gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16)
|
||||||
|
self.insert_audio_chunk(gap_silence)
|
||||||
|
else:
|
||||||
|
self.init(offset=silence_duration + offset)
|
||||||
|
self.global_time_offset += silence_duration
|
||||||
|
|
||||||
def prompt(self) -> Tuple[str, str]:
|
def prompt(self) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Returns a tuple: (prompt, context), where:
|
Returns a tuple: (prompt, context), where:
|
||||||
@@ -242,6 +246,9 @@ class OnlineASRProcessor:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
||||||
)
|
)
|
||||||
|
if self.global_time_offset:
|
||||||
|
for token in committed_tokens:
|
||||||
|
token = token.with_offset(self.global_time_offset)
|
||||||
return committed_tokens, current_audio_processed_upto
|
return committed_tokens, current_audio_processed_upto
|
||||||
|
|
||||||
def chunk_completed_sentence(self):
|
def chunk_completed_sentence(self):
|
||||||
@@ -403,330 +410,3 @@ class OnlineASRProcessor:
|
|||||||
start = None
|
start = None
|
||||||
end = None
|
end = None
|
||||||
return Transcript(start, end, text, probability=probability)
|
return Transcript(start, end, text, probability=probability)
|
||||||
|
|
||||||
|
|
||||||
class VACOnlineASRProcessor:
|
|
||||||
"""
|
|
||||||
Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
|
|
||||||
|
|
||||||
It receives small chunks of audio, applies VAD (e.g. with Silero),
|
|
||||||
and when the system detects a pause in speech (or end of an utterance)
|
|
||||||
it finalizes the utterance immediately.
|
|
||||||
"""
|
|
||||||
SAMPLING_RATE = 16000
|
|
||||||
|
|
||||||
def __init__(self, online_chunk_size: float, *args, **kwargs):
|
|
||||||
self.online_chunk_size = online_chunk_size
|
|
||||||
self.online = OnlineASRProcessor(*args, **kwargs)
|
|
||||||
self.asr = self.online.asr
|
|
||||||
|
|
||||||
# Load a VAD model (e.g. Silero VAD)
|
|
||||||
import torch
|
|
||||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
|
||||||
from .silero_vad_iterator import FixedVADIterator
|
|
||||||
|
|
||||||
self.vac = FixedVADIterator(model)
|
|
||||||
self.logfile = self.online.logfile
|
|
||||||
self.last_input_audio_stream_end_time: float = 0.0
|
|
||||||
self.init()
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
self.online.init()
|
|
||||||
self.vac.reset_states()
|
|
||||||
self.current_online_chunk_buffer_size = 0
|
|
||||||
self.last_input_audio_stream_end_time = self.online.buffer_time_offset
|
|
||||||
self.is_currently_final = False
|
|
||||||
self.status: Optional[str] = None # "voice" or "nonvoice"
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
|
||||||
self.buffer_offset = 0 # in frames
|
|
||||||
|
|
||||||
def get_audio_buffer_end_time(self) -> float:
|
|
||||||
"""Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
|
|
||||||
return self.online.get_audio_buffer_end_time()
|
|
||||||
|
|
||||||
def clear_buffer(self):
|
|
||||||
self.buffer_offset += len(self.audio_buffer)
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
|
||||||
"""
|
|
||||||
Process an incoming small audio chunk:
|
|
||||||
- run VAD on the chunk,
|
|
||||||
- decide whether to send the audio to the online ASR processor immediately,
|
|
||||||
- and/or to mark the current utterance as finished.
|
|
||||||
"""
|
|
||||||
self.last_input_audio_stream_end_time = audio_stream_end_time
|
|
||||||
res = self.vac(audio)
|
|
||||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
|
||||||
|
|
||||||
if res is not None:
|
|
||||||
# VAD returned a result; adjust the frame number
|
|
||||||
frame = list(res.values())[0] - self.buffer_offset
|
|
||||||
if "start" in res and "end" not in res:
|
|
||||||
self.status = "voice"
|
|
||||||
send_audio = self.audio_buffer[frame:]
|
|
||||||
self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
|
|
||||||
self.online.insert_audio_chunk(send_audio)
|
|
||||||
self.current_online_chunk_buffer_size += len(send_audio)
|
|
||||||
self.clear_buffer()
|
|
||||||
elif "end" in res and "start" not in res:
|
|
||||||
self.status = "nonvoice"
|
|
||||||
send_audio = self.audio_buffer[:frame]
|
|
||||||
self.online.insert_audio_chunk(send_audio)
|
|
||||||
self.current_online_chunk_buffer_size += len(send_audio)
|
|
||||||
self.is_currently_final = True
|
|
||||||
self.clear_buffer()
|
|
||||||
else:
|
|
||||||
beg = res["start"] - self.buffer_offset
|
|
||||||
end = res["end"] - self.buffer_offset
|
|
||||||
self.status = "nonvoice"
|
|
||||||
send_audio = self.audio_buffer[beg:end]
|
|
||||||
self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
|
|
||||||
self.online.insert_audio_chunk(send_audio)
|
|
||||||
self.current_online_chunk_buffer_size += len(send_audio)
|
|
||||||
self.is_currently_final = True
|
|
||||||
self.clear_buffer()
|
|
||||||
else:
|
|
||||||
if self.status == "voice":
|
|
||||||
self.online.insert_audio_chunk(self.audio_buffer)
|
|
||||||
self.current_online_chunk_buffer_size += len(self.audio_buffer)
|
|
||||||
self.clear_buffer()
|
|
||||||
else:
|
|
||||||
# Keep 1 second worth of audio in case VAD later detects voice,
|
|
||||||
# but trim to avoid unbounded memory usage.
|
|
||||||
self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
|
|
||||||
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
|
|
||||||
|
|
||||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
"""
|
|
||||||
Depending on the VAD status and the amount of accumulated audio,
|
|
||||||
process the current audio chunk.
|
|
||||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
|
||||||
"""
|
|
||||||
if self.is_currently_final:
|
|
||||||
return self.finish()
|
|
||||||
elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
|
|
||||||
self.current_online_chunk_buffer_size = 0
|
|
||||||
return self.online.process_iter()
|
|
||||||
else:
|
|
||||||
logger.debug("No online update, only VAD")
|
|
||||||
return [], self.last_input_audio_stream_end_time
|
|
||||||
|
|
||||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
"""
|
|
||||||
Finish processing by flushing any remaining text.
|
|
||||||
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
|
|
||||||
"""
|
|
||||||
result_tokens, processed_upto = self.online.finish()
|
|
||||||
self.current_online_chunk_buffer_size = 0
|
|
||||||
self.is_currently_final = False
|
|
||||||
return result_tokens, processed_upto
|
|
||||||
|
|
||||||
def get_buffer(self):
|
|
||||||
"""
|
|
||||||
Get the unvalidated buffer in string format.
|
|
||||||
"""
|
|
||||||
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
|
|
||||||
|
|
||||||
|
|
||||||
class SimulStreamingOnlineProcessor:
|
|
||||||
SAMPLING_RATE = 16000
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
asr,
|
|
||||||
tokenize_method: Optional[callable] = None,
|
|
||||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
|
||||||
confidence_validation = False,
|
|
||||||
logfile=sys.stderr,
|
|
||||||
):
|
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise ImportError("SimulStreaming dependencies are not available.")
|
|
||||||
|
|
||||||
self.asr = asr
|
|
||||||
self.tokenize = tokenize_method
|
|
||||||
self.logfile = logfile
|
|
||||||
self.confidence_validation = confidence_validation
|
|
||||||
self.init()
|
|
||||||
|
|
||||||
# buffer does not work yet
|
|
||||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
|
||||||
|
|
||||||
def init(self, offset: Optional[float] = None):
|
|
||||||
"""Initialize or reset the processing state."""
|
|
||||||
self.audio_chunks = []
|
|
||||||
self.offset = offset if offset is not None else 0.0
|
|
||||||
self.is_last = False
|
|
||||||
self.beg = self.offset
|
|
||||||
self.end = self.offset
|
|
||||||
self.cumulative_audio_duration = 0.0
|
|
||||||
self.last_audio_stream_end_time = self.offset
|
|
||||||
|
|
||||||
self.committed: List[ASRToken] = []
|
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
|
||||||
self.buffer_content = ""
|
|
||||||
self.processed_audio_duration = 0.0
|
|
||||||
|
|
||||||
def get_audio_buffer_end_time(self) -> float:
|
|
||||||
"""Returns the absolute end time of the current audio buffer."""
|
|
||||||
return self.end
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
|
||||||
if torch is None:
|
|
||||||
raise ImportError("PyTorch is required for SimulStreaming but not available")
|
|
||||||
|
|
||||||
# Convert numpy array to torch tensor
|
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
|
||||||
self.audio_chunks.append(audio_tensor)
|
|
||||||
|
|
||||||
# Update timing
|
|
||||||
chunk_duration = len(audio) / self.SAMPLING_RATE
|
|
||||||
self.cumulative_audio_duration += chunk_duration
|
|
||||||
|
|
||||||
if audio_stream_end_time is not None:
|
|
||||||
self.last_audio_stream_end_time = audio_stream_end_time
|
|
||||||
self.end = audio_stream_end_time
|
|
||||||
else:
|
|
||||||
self.end = self.offset + self.cumulative_audio_duration
|
|
||||||
|
|
||||||
def prompt(self) -> Tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Returns a tuple: (prompt, context).
|
|
||||||
SimulStreaming handles prompting internally, so we return empty strings.
|
|
||||||
"""
|
|
||||||
return "", ""
|
|
||||||
|
|
||||||
def get_buffer(self):
|
|
||||||
"""
|
|
||||||
Get the unvalidated buffer content.
|
|
||||||
"""
|
|
||||||
buffer_end = self.end if hasattr(self, 'end') else None
|
|
||||||
return Transcript(
|
|
||||||
start=None,
|
|
||||||
end=buffer_end,
|
|
||||||
text=self.buffer_content,
|
|
||||||
probability=None
|
|
||||||
)
|
|
||||||
|
|
||||||
def timestamped_text(self, tokens, generation):
|
|
||||||
# From the simulstreaming repo. self.model to self.asr.model
|
|
||||||
pr = generation["progress"]
|
|
||||||
if "result" not in generation:
|
|
||||||
split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens)
|
|
||||||
else:
|
|
||||||
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
|
|
||||||
|
|
||||||
frames = [p["most_attended_frames"][0] for p in pr]
|
|
||||||
tokens = tokens.copy()
|
|
||||||
ret = []
|
|
||||||
for sw,st in zip(split_words,split_tokens):
|
|
||||||
b = None
|
|
||||||
for stt in st:
|
|
||||||
t,f = tokens.pop(0), frames.pop(0)
|
|
||||||
if t != stt:
|
|
||||||
raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
|
|
||||||
if b is None:
|
|
||||||
b = f
|
|
||||||
e = f
|
|
||||||
out = (b*0.02, e*0.02, sw)
|
|
||||||
ret.append(out)
|
|
||||||
logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
"""
|
|
||||||
Process accumulated audio chunks using SimulStreaming.
|
|
||||||
|
|
||||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
|
||||||
"""
|
|
||||||
if not self.audio_chunks:
|
|
||||||
return [], self.end
|
|
||||||
|
|
||||||
try:
|
|
||||||
# concatenate all audio chunks
|
|
||||||
if len(self.audio_chunks) == 1:
|
|
||||||
audio = self.audio_chunks[0]
|
|
||||||
else:
|
|
||||||
audio = torch.cat(self.audio_chunks, dim=0)
|
|
||||||
|
|
||||||
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
|
|
||||||
self.processed_audio_duration += audio_duration
|
|
||||||
|
|
||||||
self.audio_chunks = []
|
|
||||||
|
|
||||||
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
|
|
||||||
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
|
|
||||||
|
|
||||||
self.asr.model.insert_audio(audio)
|
|
||||||
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
|
|
||||||
ts_words = self.timestamped_text(tokens, generation_progress)
|
|
||||||
text = self.asr.model.tokenizer.decode(tokens)
|
|
||||||
|
|
||||||
new_tokens = []
|
|
||||||
for ts_word in ts_words:
|
|
||||||
|
|
||||||
start, end, word = ts_word
|
|
||||||
token = ASRToken(
|
|
||||||
start=start,
|
|
||||||
end=end,
|
|
||||||
text=word,
|
|
||||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
|
||||||
)
|
|
||||||
new_tokens.append(token)
|
|
||||||
self.committed.extend(new_tokens)
|
|
||||||
|
|
||||||
return new_tokens, self.end
|
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"SimulStreaming processing error: {e}")
|
|
||||||
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
|
|
||||||
return [], self.end
|
|
||||||
|
|
||||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
logger.debug("SimulStreaming finish() called")
|
|
||||||
self.is_last = True
|
|
||||||
final_tokens, final_time = self.process_iter()
|
|
||||||
self.is_last = False
|
|
||||||
return final_tokens, final_time
|
|
||||||
|
|
||||||
def concatenate_tokens(
|
|
||||||
self,
|
|
||||||
tokens: List[ASRToken],
|
|
||||||
sep: Optional[str] = None,
|
|
||||||
offset: float = 0
|
|
||||||
) -> Transcript:
|
|
||||||
"""Concatenate tokens into a Transcript object."""
|
|
||||||
sep = sep if sep is not None else self.asr.sep
|
|
||||||
text = sep.join(token.text for token in tokens)
|
|
||||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
|
||||||
if tokens:
|
|
||||||
start = offset + tokens[0].start
|
|
||||||
end = offset + tokens[-1].end
|
|
||||||
else:
|
|
||||||
start = None
|
|
||||||
end = None
|
|
||||||
return Transcript(start, end, text, probability=probability)
|
|
||||||
|
|
||||||
def chunk_at(self, time: float):
|
|
||||||
"""
|
|
||||||
useless but kept for compatibility
|
|
||||||
"""
|
|
||||||
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
|
|
||||||
pass
|
|
||||||
|
|
||||||
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
|
|
||||||
"""
|
|
||||||
Create simple sentences.
|
|
||||||
"""
|
|
||||||
if not tokens:
|
|
||||||
return []
|
|
||||||
|
|
||||||
full_text = " ".join(token.text for token in tokens)
|
|
||||||
sentence = Sentence(
|
|
||||||
start=tokens[0].start,
|
|
||||||
end=tokens[-1].end,
|
|
||||||
text=full_text
|
|
||||||
)
|
|
||||||
return [sentence]
|
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import librosa
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE, SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
||||||
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -68,35 +67,7 @@ def backend_factory(args):
|
|||||||
backend = args.backend
|
backend = args.backend
|
||||||
if backend == "openai-api":
|
if backend == "openai-api":
|
||||||
logger.debug("Using OpenAI API.")
|
logger.debug("Using OpenAI API.")
|
||||||
asr = OpenaiApiASR(lan=args.lan)
|
asr = OpenaiApiASR(lan=args.lan)
|
||||||
elif backend == "simulstreaming":
|
|
||||||
logger.debug("Using SimulStreaming backend.")
|
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
|
||||||
|
|
||||||
simulstreaming_kwargs = {}
|
|
||||||
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
|
||||||
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
|
||||||
'max_context_tokens', 'model_path']:
|
|
||||||
if hasattr(args, attr):
|
|
||||||
simulstreaming_kwargs[attr] = getattr(args, attr)
|
|
||||||
|
|
||||||
# Add segment_length from min_chunk_size
|
|
||||||
simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
|
|
||||||
simulstreaming_kwargs['task'] = args.task
|
|
||||||
|
|
||||||
size = args.model
|
|
||||||
t = time.time()
|
|
||||||
logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
|
|
||||||
asr = SimulStreamingASR(
|
|
||||||
modelsize=size,
|
|
||||||
lan=args.lan,
|
|
||||||
cache_dir=getattr(args, 'model_cache_dir', None),
|
|
||||||
model_dir=getattr(args, 'model_dir', None),
|
|
||||||
**simulstreaming_kwargs
|
|
||||||
)
|
|
||||||
e = time.time()
|
|
||||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
|
||||||
else:
|
else:
|
||||||
if backend == "faster-whisper":
|
if backend == "faster-whisper":
|
||||||
asr_cls = FasterWhisperASR
|
asr_cls = FasterWhisperASR
|
||||||
@@ -136,107 +107,4 @@ def backend_factory(args):
|
|||||||
tokenizer = create_tokenizer(tgt_language)
|
tokenizer = create_tokenizer(tgt_language)
|
||||||
else:
|
else:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
return asr, tokenizer
|
return asr, tokenizer
|
||||||
|
|
||||||
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
|
||||||
if args.backend == "simulstreaming":
|
|
||||||
if not SIMULSTREAMING_ONLINE_AVAILABLE:
|
|
||||||
raise SIMULSTREAMING_ERROR_AND_INSTALLATION_INSTRUCTIONS
|
|
||||||
|
|
||||||
logger.debug("Creating SimulStreaming online processor")
|
|
||||||
online = SimulStreamingOnlineProcessor(
|
|
||||||
asr,
|
|
||||||
tokenizer,
|
|
||||||
logfile=logfile,
|
|
||||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
|
||||||
confidence_validation=args.confidence_validation
|
|
||||||
)
|
|
||||||
elif args.vac:
|
|
||||||
online = VACOnlineASRProcessor(
|
|
||||||
args.min_chunk_size,
|
|
||||||
asr,
|
|
||||||
tokenizer,
|
|
||||||
logfile=logfile,
|
|
||||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
|
||||||
confidence_validation = args.confidence_validation
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
online = OnlineASRProcessor(
|
|
||||||
asr,
|
|
||||||
tokenizer,
|
|
||||||
logfile=logfile,
|
|
||||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
|
||||||
confidence_validation = args.confidence_validation
|
|
||||||
)
|
|
||||||
return online
|
|
||||||
|
|
||||||
def asr_factory(args, logfile=sys.stderr):
|
|
||||||
"""
|
|
||||||
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
|
||||||
"""
|
|
||||||
asr, tokenizer = backend_factory(args)
|
|
||||||
online = online_factory(args, asr, tokenizer, logfile=logfile)
|
|
||||||
return asr, online
|
|
||||||
|
|
||||||
def warmup_asr(asr, warmup_file=None, timeout=5):
|
|
||||||
"""
|
|
||||||
Warmup the ASR model by transcribing a short audio file.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))
|
|
||||||
|
|
||||||
if warmup_file is None:
|
|
||||||
# Download JFK sample if not already present
|
|
||||||
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
|
||||||
temp_dir = tempfile.gettempdir()
|
|
||||||
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
|
||||||
|
|
||||||
if not os.path.exists(warmup_file):
|
|
||||||
logger.debug(f"Downloading warmup file from {jfk_url}")
|
|
||||||
print(f"Downloading warmup file from {jfk_url}")
|
|
||||||
import time
|
|
||||||
import urllib.request
|
|
||||||
import urllib.error
|
|
||||||
import socket
|
|
||||||
|
|
||||||
original_timeout = socket.getdefaulttimeout()
|
|
||||||
socket.setdefaulttimeout(timeout)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
try:
|
|
||||||
urllib.request.urlretrieve(jfk_url, warmup_file)
|
|
||||||
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
|
|
||||||
except (urllib.error.URLError, socket.timeout) as e:
|
|
||||||
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
socket.setdefaulttimeout(original_timeout)
|
|
||||||
elif not warmup_file:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
|
||||||
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
|
|
||||||
try:
|
|
||||||
import librosa
|
|
||||||
audio, sr = librosa.load(warmup_file, sr=16000)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load audio file: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
if is_simulstreaming:
|
|
||||||
asr.warmup(audio)
|
|
||||||
else:
|
|
||||||
asr.transcribe(audio)
|
|
||||||
|
|
||||||
logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Warmup failed: {e}")
|
|
||||||
return False
|
|
||||||
Reference in New Issue
Block a user