mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-09 15:25:34 +00:00
Compare commits
74 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c0768e8f3 | ||
|
|
b42d8b2692 | ||
|
|
0cd885247c | ||
|
|
8e30e8010a | ||
|
|
bfec335a5f | ||
|
|
6867041254 | ||
|
|
e165916952 | ||
|
|
8532a91c7a | ||
|
|
b01b81bad0 | ||
|
|
0f79d442ee | ||
|
|
c9f60504e3 | ||
|
|
993a83546a | ||
|
|
eabd1b199a | ||
|
|
f7644268c1 | ||
|
|
34e8fe260e | ||
|
|
debfefaf3e | ||
|
|
101ca9ef90 | ||
|
|
94bb05d53e | ||
|
|
6797b88176 | ||
|
|
46770efd6c | ||
|
|
b23ef3ec3e | ||
|
|
fa29a24abe | ||
|
|
fea3c3553c | ||
|
|
d6d65a663b | ||
|
|
083d5b2f44 | ||
|
|
8e4674b093 | ||
|
|
bc7c32100f | ||
|
|
c4150894af | ||
|
|
25bf242ce1 | ||
|
|
14cc601a5c | ||
|
|
34d5d513fa | ||
|
|
2ab3dac948 | ||
|
|
b56fcffde1 | ||
|
|
2def194893 | ||
|
|
29978da301 | ||
|
|
b708890788 | ||
|
|
3ac4c514cf | ||
|
|
3c58bfcfa2 | ||
|
|
d53b7a323a | ||
|
|
02de5993e6 | ||
|
|
d94560ef37 | ||
|
|
f62baa80b7 | ||
|
|
0b43035701 | ||
|
|
704170ccf3 | ||
|
|
09279c572a | ||
|
|
23e41f993f | ||
|
|
c791b1e125 | ||
|
|
3de2990ec4 | ||
|
|
51e6a6f6f9 | ||
|
|
f6e53b2fab | ||
|
|
5d6f08ff7a | ||
|
|
583a26da88 | ||
|
|
5b3d8969e8 | ||
|
|
40cca184c1 | ||
|
|
47ed345f9e | ||
|
|
9c9c179684 | ||
|
|
b870c12f62 | ||
|
|
cfd5905fd4 | ||
|
|
2399487e45 | ||
|
|
afd88310fd | ||
|
|
080f446b0d | ||
|
|
8bd2b36488 | ||
|
|
25fd924bf9 | ||
|
|
ff8fd0ec72 | ||
|
|
e99f53e649 | ||
|
|
e9022894b2 | ||
|
|
ccf99cecdf | ||
|
|
40e2814cd7 | ||
|
|
cd29eace3d | ||
|
|
38cb54640f | ||
|
|
81268a7ca3 | ||
|
|
33cbd24964 | ||
|
|
e966e78584 | ||
|
|
c13d36b5e7 |
82
Dockerfile
Normal file
82
Dockerfile
Normal file
@@ -0,0 +1,82 @@
|
||||
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ARG EXTRAS
|
||||
ARG HF_PRECACHE_DIR
|
||||
ARG HF_TKN_FILE
|
||||
|
||||
# Install system dependencies
|
||||
#RUN apt-get update && \
|
||||
# apt-get install -y ffmpeg git && \
|
||||
# apt-get clean && \
|
||||
# rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 2) Install system dependencies + Python + pip
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
python3 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
git && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
COPY . .
|
||||
|
||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
||||
# Note: For gates modedls, need to add your HF toke. See README.md
|
||||
# for more details.
|
||||
RUN if [ -n "$EXTRAS" ]; then \
|
||||
echo "Installing with extras: [$EXTRAS]"; \
|
||||
pip install --no-cache-dir .[$EXTRAS]; \
|
||||
else \
|
||||
echo "Installing base package only"; \
|
||||
pip install --no-cache-dir .; \
|
||||
fi
|
||||
|
||||
# Enable in-container caching for Hugging Face models by:
|
||||
# Note: If running multiple containers, better to map a shared
|
||||
# bucket.
|
||||
#
|
||||
# A) Make the cache directory persistent via an anonymous volume.
|
||||
# Note: This only persists for a single, named container. This is
|
||||
# only for convenience at de/test stage.
|
||||
# For prod, it is better to use a named volume via host mount/k8s.
|
||||
VOLUME ["/root/.cache/huggingface/hub"]
|
||||
|
||||
# or
|
||||
# B) Conditionally copy a local pre-cache from the build context to the
|
||||
# container's cache via the HF_PRECACHE_DIR build-arg.
|
||||
# WARNING: This will copy ALL files in the pre-cache location.
|
||||
|
||||
# Conditionally copy a cache directory if provided
|
||||
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
|
||||
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
|
||||
mkdir -p /root/.cache/huggingface/hub && \
|
||||
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
|
||||
else \
|
||||
echo "No local Hugging Face cache specified, skipping copy"; \
|
||||
fi
|
||||
|
||||
# Conditionally copy a Hugging Face token if provided
|
||||
|
||||
RUN if [ -n "$HF_TKN_FILE" ]; then \
|
||||
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
|
||||
mkdir -p /root/.cache/huggingface && \
|
||||
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
|
||||
else \
|
||||
echo "No Hugging Face token file specified, skipping token setup"; \
|
||||
fi
|
||||
|
||||
# Expose port for the transcription server
|
||||
EXPOSE 8000
|
||||
|
||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||
|
||||
# Default args
|
||||
CMD ["--model", "tiny.en"]
|
||||
33
LICENSE
33
LICENSE
@@ -1,21 +1,28 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 ÚFAL
|
||||
Copyright (c) 2025 Quentin Fuxa.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
---
|
||||
|
||||
Based on:
|
||||
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
|
||||
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
|
||||
438
README.md
438
README.md
@@ -1,158 +1,366 @@
|
||||
<h1 align="center">WhisperLiveKit</h1>
|
||||
<p align="center"><b>Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization</b></p>
|
||||
|
||||
|
||||
This project is based on [Whisper Streaming](https://github.com/ufal/whisper_streaming) and lets you transcribe audio directly from your browser. Simply launch the local server and grant microphone access. Everything runs locally on your machine ✨
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/demo.png" alt="Demo Screenshot" width="730">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
||||
</p>
|
||||
|
||||
### Differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
|
||||
<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Diarization</b></p>
|
||||
|
||||
#### ⚙️ **Core Improvements**
|
||||
- **Buffering Preview** – Displays unvalidated transcription segments
|
||||
- **Multi-User Support** – Handles multiple users simultaneously by decoupling backend and online asr
|
||||
- **MLX Whisper Backend** – Optimized for Apple Silicon for faster local processing.
|
||||
- **Confidence validation** – Immediately validate high-confidence tokens for faster inference
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-dark_green"></a>
|
||||
</p>
|
||||
|
||||
#### 🎙️ **Speaker Identification**
|
||||
- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart)
|
||||
## 🚀 Overview
|
||||
|
||||
#### 🌐 **Web & API**
|
||||
- **Built-in Web UI** – Simple raw html browser interface with no frontend setup required
|
||||
- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
|
||||
- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
|
||||
This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine ✨
|
||||
|
||||
## Installation
|
||||
### 🔄 Architecture
|
||||
|
||||
### Via pip
|
||||
WhisperLiveKit consists of three main components:
|
||||
|
||||
- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the provided template at [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
|
||||
- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
|
||||
- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
|
||||
|
||||
|
||||
### ✨ Key Features
|
||||
|
||||
- **🎙️ Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
|
||||
- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
|
||||
- **🌐 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
|
||||
- **🔇 Automatic Silence Chunking** – Automatically chunks when no audio is detected to limit buffer size
|
||||
- **✅ Confidence Validation** – Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
|
||||
- **👁️ Buffering Preview** – Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
|
||||
- **✒️ Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
|
||||
- **⚡ SimulStreaming Backend** - Ultra-low latency transcription using state-of-the-art AlignAtt policy. The code is not directly included in the repo : To use, please copy [simul_whisper](https://github.com/ufal/SimulStreaming/tree/main/simul_whisper) content into `whisperlivekit/simul_whisper` . ⚠️ You must comply with the [Polyform license](https://github.com/ufal/SimulStreaming/blob/main/LICENCE.txt)
|
||||
|
||||
|
||||
## 📖 Quick Start
|
||||
|
||||
```bash
|
||||
# Install the package
|
||||
pip install whisperlivekit
|
||||
|
||||
# Start the transcription server
|
||||
whisperlivekit-server --model tiny.en
|
||||
|
||||
# Open your browser at http://localhost:8000 to see the interface.
|
||||
# Use -ssl-certfile public.crt --ssl-keyfile private.key parameters to use SSL
|
||||
```
|
||||
|
||||
That's it! Start speaking and watch your words appear on screen.
|
||||
|
||||
## 🛠️ Installation Options
|
||||
|
||||
### Install from PyPI (Recommended)
|
||||
|
||||
```bash
|
||||
pip install whisperlivekit
|
||||
```
|
||||
|
||||
### From source
|
||||
### Install from Source
|
||||
|
||||
1. **Clone the Repository**:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/QuentinFuxa/WhisperLiveKit
|
||||
cd WhisperLiveKit
|
||||
pip install -e .
|
||||
```
|
||||
```bash
|
||||
git clone https://github.com/QuentinFuxa/WhisperLiveKit
|
||||
cd WhisperLiveKit
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### System Dependencies
|
||||
|
||||
You need to install FFmpeg on your system:
|
||||
FFmpeg is required:
|
||||
|
||||
- Install system dependencies:
|
||||
```bash
|
||||
# Install FFmpeg on your system (required for audio processing)
|
||||
# For Ubuntu/Debian:
|
||||
sudo apt install ffmpeg
|
||||
|
||||
# For macOS:
|
||||
brew install ffmpeg
|
||||
|
||||
# For Windows:
|
||||
# Download from https://ffmpeg.org/download.html and add to PATH
|
||||
```
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt install ffmpeg
|
||||
|
||||
- Install required Python dependencies:
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
```bash
|
||||
# Whisper streaming required dependencies
|
||||
pip install librosa soundfile
|
||||
# Windows
|
||||
# Download from https://ffmpeg.org/download.html and add to PATH
|
||||
```
|
||||
|
||||
# Whisper streaming web required dependencies
|
||||
pip install fastapi ffmpeg-python
|
||||
```
|
||||
- Install at least one whisper backend among:
|
||||
### Optional Dependencies
|
||||
|
||||
```
|
||||
whisper
|
||||
whisper-timestamped
|
||||
faster-whisper (faster backend on NVIDIA GPU)
|
||||
mlx-whisper (faster backend on Apple Silicon)
|
||||
```bash
|
||||
# Voice Activity Controller (prevents hallucinations)
|
||||
pip install torch
|
||||
|
||||
# Sentence-based buffer trimming
|
||||
pip install mosestokenizer wtpsplit
|
||||
pip install tokenize_uk # If you work with Ukrainian text
|
||||
|
||||
# Speaker diarization
|
||||
pip install diart
|
||||
|
||||
# Alternative Whisper backends (default is faster-whisper)
|
||||
pip install whisperlivekit[whisper] # Original Whisper
|
||||
pip install whisperlivekit[whisper-timestamped] # Improved timestamps
|
||||
pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
|
||||
pip install whisperlivekit[openai] # OpenAI API
|
||||
pip install whisperlivekit[simulstreaming]
|
||||
```
|
||||
|
||||
### 🎹 Pyannote Models Setup
|
||||
|
||||
For diarization, you need access to pyannote.audio models:
|
||||
|
||||
1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||
2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
|
||||
3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
|
||||
4. Login with HuggingFace:
|
||||
```bash
|
||||
pip install huggingface_hub
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## 💻 Usage Examples
|
||||
|
||||
### Command-line Interface
|
||||
|
||||
Start the transcription server with various options:
|
||||
|
||||
```bash
|
||||
# Basic server with English model
|
||||
whisperlivekit-server --model tiny.en
|
||||
|
||||
# Advanced configuration with diarization
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto
|
||||
|
||||
# SimulStreaming backend for ultra-low latency
|
||||
whisperlivekit-server --backend simulstreaming --model large-v3 --frame-threshold 20
|
||||
```
|
||||
|
||||
|
||||
### Python API Integration (Backend)
|
||||
Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
|
||||
|
||||
```python
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
# Global variable for the transcription engine
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
# Example: Initialize with specific parameters directly
|
||||
# You can also load from command-line arguments using parse_args()
|
||||
# args = parse_args()
|
||||
# transcription_engine = TranscriptionEngine(**vars(args))
|
||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Serve the web interface
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(get_web_interface_html())
|
||||
|
||||
# Process WebSocket connections
|
||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
||||
try:
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
except WebSocketDisconnect:
|
||||
print("WebSocket disconnected during results handling.")
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global transcription_engine
|
||||
|
||||
# Create a new AudioProcessor for each connection, passing the shared engine
|
||||
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
send_results_to_client = handle_websocket_results(websocket, results_generator)
|
||||
results_task = asyncio.create_task(send_results_to_client)
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
except WebSocketDisconnect:
|
||||
print(f"Client disconnected: {websocket.client}")
|
||||
except Exception as e:
|
||||
await websocket.close(code=1011, reason=f"Server error: {e}")
|
||||
finally:
|
||||
results_task.cancel()
|
||||
try:
|
||||
await results_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Results task successfully cancelled.")
|
||||
```
|
||||
|
||||
### Frontend Implementation
|
||||
|
||||
The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it in `whisperlivekit/web/live_transcription.html`, or load its content using the `get_web_interface_html()` function from `whisperlivekit`:
|
||||
|
||||
```python
|
||||
from whisperlivekit import get_web_interface_html
|
||||
|
||||
# ... later in your code where you need the HTML string ...
|
||||
html_content = get_web_interface_html()
|
||||
```
|
||||
|
||||
## ⚙️ Configuration Reference
|
||||
|
||||
WhisperLiveKit offers extensive configuration options:
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--host` | Server host address | `localhost` |
|
||||
| `--port` | Server port | `8000` |
|
||||
| `--model` | Whisper model size. Caution : '.en' models do not work with Simulstreaming | `tiny` |
|
||||
| `--language` | Source language code or `auto` | `en` |
|
||||
| `--task` | `transcribe` or `translate` | `transcribe` |
|
||||
| `--backend` | Processing backend | `faster-whisper` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
|
||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
|
||||
| `--vac` | Use Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
**SimulStreaming-specific Options:**
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||
| `--audio-max-len` | Maximum audio buffer length (seconds) | `30.0` |
|
||||
| `--audio-min-len` | Minimum audio length to process (seconds) | `0.0` |
|
||||
| `--cif-ckpt-path` | Path to CIF model for word boundary detection | `None` |
|
||||
| `--never-fire` | Never truncate incomplete words | `False` |
|
||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format
|
||||
2. **Streaming**: Audio chunks are sent to the server via WebSocket
|
||||
3. **Processing**: Server decodes audio with FFmpeg and streams into Whisper for transcription
|
||||
4. **Real-time Output**:
|
||||
- Partial transcriptions appear immediately in light gray (the 'aperçu')
|
||||
- Finalized text appears in normal color
|
||||
- (When enabled) Different speakers are identified and highlighted
|
||||
|
||||
## 🚀 Deployment Guide
|
||||
|
||||
To deploy WhisperLiveKit in production:
|
||||
|
||||
1. **Server Setup** (Backend):
|
||||
```bash
|
||||
# Install production ASGI server
|
||||
pip install uvicorn gunicorn
|
||||
|
||||
# Launch with multiple workers
|
||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
||||
```
|
||||
- Optionnal dependencies
|
||||
|
||||
```
|
||||
# If you want to use VAC (Voice Activity Controller). Useful for preventing hallucinations
|
||||
torch
|
||||
|
||||
# If you choose sentences as buffer trimming strategy
|
||||
mosestokenizer
|
||||
wtpsplit
|
||||
tokenize_uk # If you work with Ukrainian text
|
||||
2. **Frontend Integration**:
|
||||
- Host your customized version of the example HTML/JS in your web application
|
||||
- Ensure WebSocket connection points to your server's address
|
||||
|
||||
# If you want to run the server using uvicorn (recommended)
|
||||
uvicorn
|
||||
3. **Nginx Configuration** (recommended for production):
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name your-domain.com;
|
||||
|
||||
# If you want to use diarization
|
||||
diart
|
||||
```
|
||||
location / {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Diart uses by default [pyannote.audio](https://github.com/pyannote/pyannote-audio) models from the _huggingface hub_. To use them, please follow the steps described [here](https://github.com/juanmc2005/diart?tab=readme-ov-file#get-access-to--pyannote-models).
|
||||
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
||||
|
||||
### 🐋 Docker
|
||||
|
||||
3. **Run the FastAPI Server**:
|
||||
A basic Dockerfile is provided which allows re-use of Python package installation options. See below usage examples:
|
||||
|
||||
```bash
|
||||
python whisper_fastapi_online_server.py --host 0.0.0.0 --port 8000
|
||||
```
|
||||
**NOTE:** For **larger** models, ensure that your **docker runtime** has enough **memory** available.
|
||||
|
||||
**Parameters**
|
||||
|
||||
The following parameters are supported:
|
||||
|
||||
- `--host` and `--port` let you specify the server's IP/port.
|
||||
- `-min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
|
||||
- `--transcription`: Enable/disable transcription (default: True)
|
||||
- `--diarization`: Enable/disable speaker diarization (default: False)
|
||||
- `--confidence-validation`: Use confidence scores for faster validation. Transcription will be faster but punctuation might be less accurate (default: True)
|
||||
- `--warmup-file`: The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. :
|
||||
- If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
- If False, no warmup is performed.
|
||||
- `--min-chunk-size` Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.
|
||||
- `--model` {_tiny.en, tiny, base.en, base, small.en, small, medium.en, medium, large-v1, large-v2, large-v3, large, large-v3-turbo_}
|
||||
Name size of the Whisper model to use (default: tiny). The model is automatically downloaded from the model hub if not present in model cache dir.
|
||||
- `--model_cache_dir` Overriding the default model cache dir where models downloaded from the hub are saved
|
||||
- `--model_dir` Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.
|
||||
- `--lan`, --language Source language code, e.g. en,de,cs, or 'auto' for language detection.
|
||||
- `--task` {_transcribe, translate_} Transcribe or translate. If translate is set, we recommend avoiding the _large-v3-turbo_ backend, as it [performs significantly worse](https://github.com/QuentinFuxa/whisper_streaming_web/issues/40#issuecomment-2652816533) than other models for translation.
|
||||
- `--backend` {_faster-whisper, whisper_timestamped, openai-api, mlx-whisper_} Load only this backend for Whisper processing.
|
||||
- `--vac` Use VAC = voice activity controller. Requires torch.
|
||||
- `--vac-chunk-size` VAC sample size in seconds.
|
||||
- `--vad` Use VAD = voice activity detection, with the default parameters.
|
||||
- `--buffer_trimming` {_sentence, segment_} Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.
|
||||
- `--buffer_trimming_sec` Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.
|
||||
#### All defaults
|
||||
- Create a reusable image with only the basics and then run as a named container:
|
||||
```bash
|
||||
docker build -t whisperlivekit-defaults .
|
||||
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults
|
||||
docker start -i whisperlivekit
|
||||
```
|
||||
|
||||
5. **Open the Provided HTML**:
|
||||
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
|
||||
|
||||
- By default, the server root endpoint `/` serves a simple `live_transcription.html` page.
|
||||
- Open your browser at `http://localhost:8000` (or replace `localhost` and `8000` with whatever you specified).
|
||||
- The page uses vanilla JavaScript and the WebSocket API to capture your microphone and stream audio to the server in real time.
|
||||
#### Customization
|
||||
- Customize the container options:
|
||||
```bash
|
||||
docker build -t whisperlivekit-defaults .
|
||||
docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
|
||||
docker start -i whisperlivekit-base
|
||||
```
|
||||
|
||||
### How the Live Interface Works
|
||||
- `--build-arg` Options:
|
||||
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
||||
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
||||
- `HF_TOKEN="./token"` - Add your Hugging Face Hub access token to download gated models
|
||||
|
||||
- Once you **allow microphone access**, the page records small chunks of audio using the **MediaRecorder** API in **webm/opus** format.
|
||||
- These chunks are sent over a **WebSocket** to the FastAPI endpoint at `/asr`.
|
||||
- The Python server decodes `.webm` chunks on the fly using **FFmpeg** and streams them into the **whisper streaming** implementation for transcription.
|
||||
- **Partial transcription** appears as soon as enough audio is processed. The “unvalidated” text is shown in **lighter or grey color** (i.e., an ‘aperçu’) to indicate it’s still buffered partial output. Once Whisper finalizes that segment, it’s displayed in normal text.
|
||||
- You can watch the transcription update in near real time, ideal for demos, prototyping, or quick debugging.
|
||||
## 🔮 Use Cases
|
||||
|
||||
### Deploying to a Remote Server
|
||||
- **Meeting Transcription**: Capture discussions in real-time
|
||||
- **Accessibility Tools**: Help hearing-impaired users follow conversations
|
||||
- **Content Creation**: Transcribe podcasts or videos automatically
|
||||
- **Customer Service**: Transcribe support calls with speaker identification
|
||||
|
||||
If you want to **deploy** this setup:
|
||||
## 📄 License
|
||||
|
||||
1. **Host the FastAPI app** behind a production-grade HTTP(S) server (like **Uvicorn + Nginx** or Docker). If you use HTTPS, use "wss" instead of "ws" in WebSocket URL.
|
||||
2. The **HTML/JS page** can be served by the same FastAPI app or a separate static host.
|
||||
3. Users open the page in **Chrome/Firefox** (any modern browser that supports MediaRecorder + WebSocket).
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
No additional front-end libraries or frameworks are required. The WebSocket logic in `live_transcription.html` is minimal enough to adapt for your own custom UI or embed in other pages.
|
||||
**⚠️ Important**: When using the SimulStreaming backend, you must also comply with the **PolyForm Noncommercial License 1.0.0** that governs SimulStreaming. For commercial use of the SimulStreaming backend, obtain a commercial license from the [SimulStreaming authors](https://github.com/ufal/SimulStreaming#-licence-and-contributions).
|
||||
|
||||
## Acknowledgments
|
||||
## 🤝 Contributing
|
||||
|
||||
This project builds upon the foundational work of the Whisper Streaming project. We extend our gratitude to the original authors for their contributions.
|
||||
Contributions are welcome! Here's how to get started:
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch: `git checkout -b feature/amazing-feature`
|
||||
3. Commit your changes: `git commit -m 'Add amazing feature'`
|
||||
4. Push to your branch: `git push origin feature/amazing-feature`
|
||||
5. Open a Pull Request
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This project builds upon the foundational work of:
|
||||
- [Whisper Streaming](https://github.com/ufal/whisper_streaming)
|
||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (BETA backend)
|
||||
- [Diart](https://github.com/juanmc2005/diart)
|
||||
- [OpenAI Whisper](https://github.com/openai/whisper)
|
||||
|
||||
We extend our gratitude to the original authors for their contributions.
|
||||
|
||||
## 🔗 Links
|
||||
|
||||
- [GitHub Repository](https://github.com/QuentinFuxa/WhisperLiveKit)
|
||||
- [PyPI Package](https://pypi.org/project/whisperlivekit/)
|
||||
- [Issue Tracker](https://github.com/QuentinFuxa/WhisperLiveKit/issues)
|
||||
BIN
demo.png
BIN
demo.png
Binary file not shown.
|
Before Width: | Height: | Size: 469 KiB After Width: | Height: | Size: 438 KiB |
16
setup.py
16
setup.py
@@ -1,8 +1,7 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="whisperlivekit",
|
||||
version="0.1.0",
|
||||
version="0.2.1",
|
||||
description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
@@ -22,13 +21,24 @@ setup(
|
||||
"diarization": ["diart"],
|
||||
"vac": ["torch"],
|
||||
"sentence": ["mosestokenizer", "wtpsplit"],
|
||||
"whisper": ["whisper"],
|
||||
"whisper-timestamped": ["whisper-timestamped"],
|
||||
"mlx-whisper": ["mlx-whisper"],
|
||||
"openai": ["openai"],
|
||||
"simulstreaming": [
|
||||
"torch",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
|
||||
],
|
||||
},
|
||||
package_data={
|
||||
'whisperlivekit': ['web/*.html'],
|
||||
'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
|
||||
},
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'whisperlivekit-server=whisperlivekit.server:run_server',
|
||||
'whisperlivekit-server=whisperlivekit.basic_server:main',
|
||||
],
|
||||
},
|
||||
classifiers=[
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from whisperlivekit import WhisperLiveKit
|
||||
from whisperlivekit.audio_processor import AudioProcessor
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
kit = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global kit
|
||||
kit = WhisperLiveKit()
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(kit.web_interface())
|
||||
|
||||
|
||||
async def handle_websocket_results(websocket, results_generator):
|
||||
"""Consumes results from the audio processor and sends them via WebSocket."""
|
||||
try:
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in WebSocket results handler: {e}")
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
audio_processor = AudioProcessor()
|
||||
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
except WebSocketDisconnect:
|
||||
logger.warning("WebSocket disconnected.")
|
||||
finally:
|
||||
websocket_task.cancel()
|
||||
await audio_processor.cleanup()
|
||||
logger.info("WebSocket endpoint cleaned up.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
temp_kit = WhisperLiveKit(transcription=False, diarization=False)
|
||||
|
||||
uvicorn.run(
|
||||
"whisper_fastapi_online_server:app",
|
||||
host=temp_kit.args.host,
|
||||
port=temp_kit.args.port,
|
||||
reload=True,
|
||||
log_level="info"
|
||||
)
|
||||
@@ -1,4 +1,5 @@
|
||||
from .core import WhisperLiveKit, parse_args
|
||||
from .core import TranscriptionEngine
|
||||
from .audio_processor import AudioProcessor
|
||||
|
||||
__all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
|
||||
from .web.web_interface import get_web_interface_html
|
||||
from .parse_args import parse_args
|
||||
__all__ = ['TranscriptionEngine', 'AudioProcessor', 'get_web_interface_html', 'parse_args']
|
||||
@@ -6,16 +6,17 @@ import math
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from typing import List, Dict, Any
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import online_factory
|
||||
from whisperlivekit.core import WhisperLiveKit
|
||||
from whisperlivekit.core import TranscriptionEngine
|
||||
|
||||
# Set up logging once
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
@@ -26,10 +27,13 @@ class AudioProcessor:
|
||||
Handles audio processing, state management, and result formatting.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the audio processor with configuration, models, and state."""
|
||||
|
||||
models = WhisperLiveKit()
|
||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||
models = kwargs['transcription_engine']
|
||||
else:
|
||||
models = TranscriptionEngine(**kwargs)
|
||||
|
||||
# Audio processing settings
|
||||
self.args = models.args
|
||||
@@ -39,8 +43,12 @@ class AudioProcessor:
|
||||
self.bytes_per_sample = 2
|
||||
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
||||
|
||||
self.last_ffmpeg_activity = time()
|
||||
self.ffmpeg_health_check_interval = 5
|
||||
self.ffmpeg_max_idle_time = 10
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.tokens = []
|
||||
self.buffer_transcription = ""
|
||||
self.buffer_diarization = ""
|
||||
@@ -60,6 +68,13 @@ class AudioProcessor:
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
self.pcm_buffer = bytearray()
|
||||
|
||||
# Task references
|
||||
self.transcription_task = None
|
||||
self.diarization_task = None
|
||||
self.ffmpeg_reader_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
|
||||
# Initialize transcription engine if enabled
|
||||
if self.args.transcription:
|
||||
@@ -71,21 +86,80 @@ class AudioProcessor:
|
||||
|
||||
def start_ffmpeg_decoder(self):
|
||||
"""Start FFmpeg process for WebM to PCM conversion."""
|
||||
return (ffmpeg.input("pipe:0", format="webm")
|
||||
.output("pipe:1", format="s16le", acodec="pcm_s16le",
|
||||
ac=self.channels, ar=str(self.sample_rate))
|
||||
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
|
||||
try:
|
||||
return (ffmpeg.input("pipe:0", format="webm")
|
||||
.output("pipe:1", format="s16le", acodec="pcm_s16le",
|
||||
ac=self.channels, ar=str(self.sample_rate))
|
||||
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
|
||||
except FileNotFoundError:
|
||||
error = """
|
||||
FFmpeg is not installed or not found in your system's PATH.
|
||||
Please install FFmpeg to enable audio processing.
|
||||
|
||||
Installation instructions:
|
||||
|
||||
# Ubuntu/Debian:
|
||||
sudo apt update && sudo apt install ffmpeg
|
||||
|
||||
# macOS (using Homebrew):
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows:
|
||||
# 1. Download the latest static build from https://ffmpeg.org/download.html
|
||||
# 2. Extract the archive (e.g., to C:\\FFmpeg).
|
||||
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
|
||||
|
||||
After installation, please restart the application.
|
||||
"""
|
||||
logger.error(error)
|
||||
raise FileNotFoundError(error)
|
||||
|
||||
async def restart_ffmpeg(self):
|
||||
"""Restart the FFmpeg process after failure."""
|
||||
logger.warning("Restarting FFmpeg process...")
|
||||
|
||||
if self.ffmpeg_process:
|
||||
try:
|
||||
self.ffmpeg_process.kill()
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
|
||||
# we check if process is still running
|
||||
if self.ffmpeg_process.poll() is None:
|
||||
logger.info("Terminating existing FFmpeg process")
|
||||
self.ffmpeg_process.stdin.close()
|
||||
self.ffmpeg_process.terminate()
|
||||
|
||||
# wait for termination with timeout
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait),
|
||||
timeout=5.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("FFmpeg process did not terminate, killing forcefully")
|
||||
self.ffmpeg_process.kill()
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error killing FFmpeg process: {e}")
|
||||
logger.error(f"Error during FFmpeg process termination: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# we start new process
|
||||
try:
|
||||
logger.info("Starting new FFmpeg process")
|
||||
self.ffmpeg_process = self.start_ffmpeg_decoder()
|
||||
self.pcm_buffer = bytearray()
|
||||
self.last_ffmpeg_activity = time()
|
||||
logger.info("FFmpeg process restarted successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to restart FFmpeg process: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
# try again after 5s
|
||||
await asyncio.sleep(5)
|
||||
try:
|
||||
self.ffmpeg_process = self.start_ffmpeg_decoder()
|
||||
self.pcm_buffer = bytearray()
|
||||
self.last_ffmpeg_activity = time()
|
||||
logger.info("FFmpeg process restarted successfully on second attempt")
|
||||
except Exception as e2:
|
||||
logger.critical(f"Failed to restart FFmpeg process on second attempt: {e2}")
|
||||
logger.critical(traceback.format_exc())
|
||||
|
||||
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
|
||||
"""Thread-safe update of transcription with new data."""
|
||||
@@ -154,25 +228,25 @@ class AudioProcessor:
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Calculate buffer size based on elapsed time
|
||||
elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec
|
||||
current_time = time()
|
||||
elapsed_time = math.floor((current_time - beg) * 10) / 10
|
||||
buffer_size = max(int(32000 * elapsed_time), 4096)
|
||||
beg = time()
|
||||
beg = current_time
|
||||
|
||||
# Read chunk with timeout
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
|
||||
timeout=15.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("FFmpeg read timeout. Restarting...")
|
||||
# Detect idle state much more quickly
|
||||
if current_time - self.last_ffmpeg_activity > self.ffmpeg_max_idle_time:
|
||||
logger.warning(f"FFmpeg process idle for {current_time - self.last_ffmpeg_activity:.2f}s. Restarting...")
|
||||
await self.restart_ffmpeg()
|
||||
beg = time()
|
||||
self.last_ffmpeg_activity = time()
|
||||
continue
|
||||
|
||||
chunk = await loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size)
|
||||
if chunk:
|
||||
self.last_ffmpeg_activity = time()
|
||||
|
||||
if not chunk:
|
||||
logger.info("FFmpeg stdout closed.")
|
||||
logger.info("FFmpeg stdout closed, no more data to read.")
|
||||
break
|
||||
|
||||
self.pcm_buffer.extend(chunk)
|
||||
@@ -183,7 +257,7 @@ class AudioProcessor:
|
||||
self.convert_pcm_to_float(self.pcm_buffer).copy()
|
||||
)
|
||||
|
||||
# Process when we have enough data
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
||||
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||
logger.warning(
|
||||
@@ -207,45 +281,86 @@ class AudioProcessor:
|
||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
break
|
||||
|
||||
logger.info("FFmpeg stdout processing finished. Signaling downstream processors.")
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
logger.debug("Sentinel put into transcription_queue.")
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
logger.debug("Sentinel put into diarization_queue.")
|
||||
|
||||
|
||||
async def transcription_processor(self):
|
||||
"""Process audio chunks for transcription."""
|
||||
self.full_transcription = ""
|
||||
self.sep = self.online.asr.sep
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
|
||||
while True:
|
||||
try:
|
||||
pcm_array = await self.transcription_queue.get()
|
||||
if pcm_array is SENTINEL:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
self.transcription_queue.task_done()
|
||||
break
|
||||
|
||||
logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")
|
||||
if not self.online: # Should not happen if queue is used
|
||||
logger.warning("Transcription processor: self.online not initialized.")
|
||||
self.transcription_queue.task_done()
|
||||
continue
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
|
||||
logger.info(
|
||||
f"ASR processing: internal_buffer={asr_internal_buffer_duration_s:.2f}s, "
|
||||
f"lag={transcription_lag_s:.2f}s."
|
||||
)
|
||||
|
||||
# Process transcription
|
||||
self.online.insert_audio_chunk(pcm_array)
|
||||
new_tokens = self.online.process_iter()
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate if isinstance(pcm_array, np.ndarray) else 0
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
|
||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = self.online.process_iter()
|
||||
|
||||
if new_tokens:
|
||||
self.full_transcription += self.sep.join([t.text for t in new_tokens])
|
||||
|
||||
# Get buffer information
|
||||
_buffer = self.online.get_buffer()
|
||||
buffer = _buffer.text
|
||||
end_buffer = _buffer.end if _buffer.end else (
|
||||
new_tokens[-1].end if new_tokens else 0
|
||||
)
|
||||
|
||||
# Avoid duplicating content
|
||||
if buffer in self.full_transcription:
|
||||
buffer = ""
|
||||
_buffer_transcript_obj = self.online.get_buffer()
|
||||
buffer_text = _buffer_transcript_obj.text
|
||||
|
||||
if new_tokens:
|
||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||
self.full_transcription += validated_text
|
||||
|
||||
if buffer_text.startswith(validated_text):
|
||||
buffer_text = buffer_text[len(validated_text):].lstrip()
|
||||
|
||||
candidate_end_times = [self.end_buffer]
|
||||
|
||||
if new_tokens:
|
||||
candidate_end_times.append(new_tokens[-1].end)
|
||||
|
||||
if _buffer_transcript_obj.end is not None:
|
||||
candidate_end_times.append(_buffer_transcript_obj.end)
|
||||
|
||||
candidate_end_times.append(current_audio_processed_upto)
|
||||
|
||||
new_end_buffer = max(candidate_end_times)
|
||||
|
||||
await self.update_transcription(
|
||||
new_tokens, buffer, end_buffer, self.full_transcription, self.sep
|
||||
new_tokens, buffer_text, new_end_buffer, self.full_transcription, self.sep
|
||||
)
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in transcription_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
self.transcription_queue.task_done()
|
||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
|
||||
self.transcription_queue.task_done()
|
||||
logger.info("Transcription processor task finished.")
|
||||
|
||||
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
"""Process audio chunks for speaker diarization."""
|
||||
@@ -254,23 +369,33 @@ class AudioProcessor:
|
||||
while True:
|
||||
try:
|
||||
pcm_array = await self.diarization_queue.get()
|
||||
if pcm_array is SENTINEL:
|
||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
# Get current state and update speakers
|
||||
state = await self.get_current_state()
|
||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||
state["end_attributed_speaker"], state["tokens"]
|
||||
)
|
||||
async with self.lock:
|
||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||
self.end_attributed_speaker,
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
self.end_attributed_speaker = new_end
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
await self.update_diarization(new_end, buffer_diarization)
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
self.diarization_queue.task_done()
|
||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL:
|
||||
self.diarization_queue.task_done()
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
|
||||
async def results_formatter(self):
|
||||
"""Format processing results for output."""
|
||||
@@ -334,31 +459,51 @@ class AudioProcessor:
|
||||
await self.update_diarization(end_attributed_speaker, combined)
|
||||
buffer_diarization = combined
|
||||
|
||||
# Create response object
|
||||
if not lines:
|
||||
lines = [{
|
||||
response_status = "active_transcription"
|
||||
final_lines_for_response = lines.copy()
|
||||
|
||||
if not tokens and not buffer_transcription and not buffer_diarization:
|
||||
response_status = "no_audio_detected"
|
||||
final_lines_for_response = []
|
||||
elif response_status == "active_transcription" and not final_lines_for_response:
|
||||
final_lines_for_response = [{
|
||||
"speaker": 1,
|
||||
"text": "",
|
||||
"beg": format_time(0),
|
||||
"end": format_time(tokens[-1].end if tokens else 0),
|
||||
"beg": format_time(state.get("end_buffer", 0)),
|
||||
"end": format_time(state.get("end_buffer", 0)),
|
||||
"diff": 0
|
||||
}]
|
||||
|
||||
response = {
|
||||
"lines": lines,
|
||||
"status": response_status,
|
||||
"lines": final_lines_for_response,
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": state["remaining_time_transcription"],
|
||||
"remaining_time_diarization": state["remaining_time_diarization"]
|
||||
}
|
||||
|
||||
# Only yield if content has changed
|
||||
response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
|
||||
f" | {buffer_transcription} | {buffer_diarization}"
|
||||
current_response_signature = f"{response_status} | " + \
|
||||
' '.join([f"{line['speaker']} {line['text']}" for line in final_lines_for_response]) + \
|
||||
f" | {buffer_transcription} | {buffer_diarization}"
|
||||
|
||||
if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
|
||||
if current_response_signature != self.last_response_content and \
|
||||
(final_lines_for_response or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
yield response
|
||||
self.last_response_content = response_content
|
||||
self.last_response_content = current_response_signature
|
||||
|
||||
# Check for termination condition
|
||||
if self.is_stopping:
|
||||
all_processors_done = True
|
||||
if self.args.transcription and self.transcription_task and not self.transcription_task.done():
|
||||
all_processors_done = False
|
||||
if self.args.diarization and self.diarization_task and not self.diarization_task.done():
|
||||
all_processors_done = False
|
||||
|
||||
if all_processors_done:
|
||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||
final_state = await self.get_current_state()
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.1) # Avoid overwhelming the client
|
||||
|
||||
@@ -366,44 +511,169 @@ class AudioProcessor:
|
||||
logger.warning(f"Exception in results_formatter: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5) # Back off on error
|
||||
|
||||
|
||||
async def create_tasks(self):
|
||||
"""Create and start processing tasks."""
|
||||
|
||||
tasks = []
|
||||
self.all_tasks_for_cleanup = []
|
||||
processing_tasks_for_watchdog = []
|
||||
|
||||
if self.args.transcription and self.online:
|
||||
tasks.append(asyncio.create_task(self.transcription_processor()))
|
||||
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
||||
self.all_tasks_for_cleanup.append(self.transcription_task)
|
||||
processing_tasks_for_watchdog.append(self.transcription_task)
|
||||
|
||||
if self.args.diarization and self.diarization:
|
||||
tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
|
||||
self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization))
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
|
||||
self.tasks = tasks
|
||||
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
|
||||
|
||||
# Monitor overall system health
|
||||
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
|
||||
self.all_tasks_for_cleanup.append(self.watchdog_task)
|
||||
|
||||
return self.results_formatter()
|
||||
|
||||
async def watchdog(self, tasks_to_monitor):
|
||||
"""Monitors the health of critical processing tasks."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
current_time = time()
|
||||
|
||||
for i, task in enumerate(tasks_to_monitor):
|
||||
if task.done():
|
||||
exc = task.exception()
|
||||
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
|
||||
if exc:
|
||||
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
|
||||
else:
|
||||
logger.info(f"{task_name} completed normally.")
|
||||
|
||||
ffmpeg_idle_time = current_time - self.last_ffmpeg_activity
|
||||
if ffmpeg_idle_time > 15:
|
||||
logger.warning(f"FFmpeg idle for {ffmpeg_idle_time:.2f}s - may need attention.")
|
||||
if ffmpeg_idle_time > 30 and not self.is_stopping:
|
||||
logger.error("FFmpeg idle for too long and not in stopping phase, forcing restart.")
|
||||
await self.restart_ffmpeg()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Watchdog task cancelled.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in watchdog task: {e}", exc_info=True)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources when processing is complete."""
|
||||
for task in self.tasks:
|
||||
task.cancel()
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
for task in self.all_tasks_for_cleanup:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
||||
if created_tasks:
|
||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||
logger.info("All processing tasks cancelled or finished.")
|
||||
|
||||
if self.ffmpeg_process:
|
||||
if self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
|
||||
try:
|
||||
self.ffmpeg_process.stdin.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing ffmpeg stdin during cleanup: {e}")
|
||||
|
||||
try:
|
||||
await asyncio.gather(*self.tasks, return_exceptions=True)
|
||||
self.ffmpeg_process.stdin.close()
|
||||
self.ffmpeg_process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during cleanup: {e}")
|
||||
|
||||
if self.args.diarization and hasattr(self, 'diarization'):
|
||||
# Wait for ffmpeg process to terminate
|
||||
if self.ffmpeg_process.poll() is None: # Check if process is still running
|
||||
logger.info("Waiting for FFmpeg process to terminate...")
|
||||
try:
|
||||
# Run wait in executor to avoid blocking async loop
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait, 5.0) # 5s timeout
|
||||
except Exception as e: # subprocess.TimeoutExpired is not directly caught by asyncio.wait_for with run_in_executor
|
||||
logger.warning(f"FFmpeg did not terminate gracefully, killing. Error: {e}")
|
||||
self.ffmpeg_process.kill()
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) # Wait for kill
|
||||
logger.info("FFmpeg process terminated.")
|
||||
|
||||
if self.args.diarization and hasattr(self, 'diarization') and hasattr(self.diarization, 'close'):
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
|
||||
async def process_audio(self, message):
|
||||
"""Process incoming audio data."""
|
||||
try:
|
||||
self.ffmpeg_process.stdin.write(message)
|
||||
self.ffmpeg_process.stdin.flush()
|
||||
except (BrokenPipeError, AttributeError) as e:
|
||||
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
||||
await self.restart_ffmpeg()
|
||||
self.ffmpeg_process.stdin.write(message)
|
||||
self.ffmpeg_process.stdin.flush()
|
||||
# If already stopping or stdin is closed, ignore further audio, especially residual chunks.
|
||||
if self.is_stopping or (self.ffmpeg_process and self.ffmpeg_process.stdin and self.ffmpeg_process.stdin.closed):
|
||||
logger.warning(f"AudioProcessor is stopping or stdin is closed. Ignoring incoming audio message (length: {len(message)}).")
|
||||
if not message and self.ffmpeg_process and self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
|
||||
logger.info("Received empty message while already in stopping state; ensuring stdin is closed.")
|
||||
try:
|
||||
self.ffmpeg_process.stdin.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing ffmpeg stdin on redundant stop signal during stopping state: {e}")
|
||||
return
|
||||
|
||||
if not message: # primary signal to start stopping
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
self.is_stopping = True
|
||||
if self.ffmpeg_process and self.ffmpeg_process.stdin and not self.ffmpeg_process.stdin.closed:
|
||||
try:
|
||||
self.ffmpeg_process.stdin.close()
|
||||
logger.info("FFmpeg stdin closed due to primary stop signal.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing ffmpeg stdin on stop: {e}")
|
||||
return
|
||||
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
|
||||
# Log periodic heartbeats showing ongoing audio proc
|
||||
current_time = time()
|
||||
if not hasattr(self, '_last_heartbeat') or current_time - self._last_heartbeat >= 10:
|
||||
logger.debug(f"Processing audio chunk, last FFmpeg activity: {current_time - self.last_ffmpeg_activity:.2f}s ago")
|
||||
self._last_heartbeat = current_time
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
if not self.ffmpeg_process or not hasattr(self.ffmpeg_process, 'stdin') or self.ffmpeg_process.poll() is not None:
|
||||
logger.warning("FFmpeg process not available, restarting...")
|
||||
await self.restart_ffmpeg()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
loop.run_in_executor(None, lambda: self.ffmpeg_process.stdin.write(message)),
|
||||
timeout=2.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("FFmpeg write operation timed out, restarting...")
|
||||
await self.restart_ffmpeg()
|
||||
retry_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
loop.run_in_executor(None, self.ffmpeg_process.stdin.flush),
|
||||
timeout=2.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("FFmpeg flush operation timed out, restarting...")
|
||||
await self.restart_ffmpeg()
|
||||
retry_count += 1
|
||||
continue
|
||||
|
||||
self.last_ffmpeg_activity = time()
|
||||
return
|
||||
|
||||
except (BrokenPipeError, AttributeError, OSError) as e:
|
||||
retry_count += 1
|
||||
logger.warning(f"Error writing to FFmpeg: {e}. Retry {retry_count}/{max_retries}...")
|
||||
|
||||
if retry_count < max_retries:
|
||||
await self.restart_ffmpeg()
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
logger.error("Maximum retries reached for FFmpeg process")
|
||||
await self.restart_ffmpeg()
|
||||
return
|
||||
|
||||
120
whisperlivekit/basic_server.py
Normal file
120
whisperlivekit/basic_server.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
args = parse_args()
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(
|
||||
**vars(args),
|
||||
)
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(get_web_interface_html())
|
||||
|
||||
|
||||
async def handle_websocket_results(websocket, results_generator):
|
||||
"""Consumes results from the audio processor and sends them via WebSocket."""
|
||||
try:
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
# when the results_generator finishes it means all audio has been processed
|
||||
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in WebSocket results handler: {e}")
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global transcription_engine
|
||||
audio_processor = AudioProcessor(
|
||||
transcription_engine=transcription_engine,
|
||||
)
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
except KeyError as e:
|
||||
if 'bytes' in str(e):
|
||||
logger.warning(f"Client has closed the connection.")
|
||||
else:
|
||||
logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected by client during message receiving loop.")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True)
|
||||
finally:
|
||||
logger.info("Cleaning up WebSocket endpoint...")
|
||||
if not websocket_task.done():
|
||||
websocket_task.cancel()
|
||||
try:
|
||||
await websocket_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("WebSocket results handler task was cancelled.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception while awaiting websocket_task completion: {e}")
|
||||
|
||||
await audio_processor.cleanup()
|
||||
logger.info("WebSocket endpoint cleaned up successfully.")
|
||||
|
||||
def main():
|
||||
"""Entry point for the CLI command."""
|
||||
import uvicorn
|
||||
|
||||
uvicorn_kwargs = {
|
||||
"app": "whisperlivekit.basic_server:app",
|
||||
"host":args.host,
|
||||
"port":args.port,
|
||||
"reload": False,
|
||||
"log_level": "info",
|
||||
"lifespan": "on",
|
||||
}
|
||||
|
||||
ssl_kwargs = {}
|
||||
if args.ssl_certfile or args.ssl_keyfile:
|
||||
if not (args.ssl_certfile and args.ssl_keyfile):
|
||||
raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
|
||||
ssl_kwargs = {
|
||||
"ssl_certfile": args.ssl_certfile,
|
||||
"ssl_keyfile": args.ssl_keyfile
|
||||
}
|
||||
|
||||
if ssl_kwargs:
|
||||
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
||||
|
||||
uvicorn.run(**uvicorn_kwargs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,139 +1,11 @@
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
from argparse import Namespace, ArgumentParser
|
||||
try:
|
||||
from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
except ImportError:
|
||||
from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
||||
from argparse import Namespace
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="The host address to bind the server to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to bind the server to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-file",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="warmup_file",
|
||||
help="""
|
||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
If False, no warmup is performed.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence-validation",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether to enable speaker diarization.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--transcription",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="To disable to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny",
|
||||
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
|
||||
","
|
||||
),
|
||||
help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Overriding the default model cache dir where models downloaded from the hub are saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lan",
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="faster-whisper",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vad",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use VAD = voice activity detection, with the default parameters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer_trimming",
|
||||
type=str,
|
||||
default="segment",
|
||||
choices=["sentence", "segment"],
|
||||
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer_trimming_sec",
|
||||
type=float,
|
||||
default=15,
|
||||
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--log-level",
|
||||
dest="log_level",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the log level",
|
||||
default="DEBUG",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
class WhisperLiveKit:
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
@@ -143,14 +15,63 @@ class WhisperLiveKit:
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if WhisperLiveKit._initialized:
|
||||
if TranscriptionEngine._initialized:
|
||||
return
|
||||
|
||||
default_args = vars(parse_args())
|
||||
|
||||
defaults = {
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"warmup_file": None,
|
||||
"confidence_validation": False,
|
||||
"diarization": False,
|
||||
"punctuation_split": False,
|
||||
"min_chunk_size": 0.5,
|
||||
"model": "tiny",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"lan": "auto",
|
||||
"task": "transcribe",
|
||||
"backend": "faster-whisper",
|
||||
"vac": False,
|
||||
"vac_chunk_size": 0.04,
|
||||
"buffer_trimming": "segment",
|
||||
"buffer_trimming_sec": 15,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
"ssl_keyfile": None,
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
# simulstreaming params:
|
||||
"frame_threshold": 25,
|
||||
"beams": 1,
|
||||
"decoder_type": None,
|
||||
"audio_max_len": 30.0,
|
||||
"audio_min_len": 0.0,
|
||||
"cif_ckpt_path": None,
|
||||
"never_fire": False,
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
}
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
|
||||
if 'no_transcription' in kwargs:
|
||||
config_dict['transcription'] = not kwargs['no_transcription']
|
||||
if 'no_vad' in kwargs:
|
||||
config_dict['vad'] = not kwargs['no_vad']
|
||||
|
||||
merged_args = {**default_args, **kwargs}
|
||||
|
||||
self.args = Namespace(**merged_args)
|
||||
config_dict.pop('no_transcription', None)
|
||||
config_dict.pop('no_vad', None)
|
||||
|
||||
if 'language' in kwargs:
|
||||
config_dict['lan'] = kwargs['language']
|
||||
config_dict.pop('language', None)
|
||||
|
||||
self.args = Namespace(**config_dict)
|
||||
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
@@ -162,13 +83,10 @@ class WhisperLiveKit:
|
||||
|
||||
if self.args.diarization:
|
||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
||||
self.diarization = DiartDiarization()
|
||||
self.diarization = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
)
|
||||
|
||||
WhisperLiveKit._initialized = True
|
||||
|
||||
def web_interface(self):
|
||||
import pkg_resources
|
||||
html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
|
||||
with open(html_path, "r", encoding="utf-8") as f:
|
||||
html = f.read()
|
||||
return html
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
0
whisperlivekit/diarization/__init__.py
Normal file
0
whisperlivekit/diarization/__init__.py
Normal file
@@ -3,7 +3,8 @@ import re
|
||||
import threading
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
import time
|
||||
from queue import SimpleQueue, Empty
|
||||
|
||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||
from diart.inference import StreamingInference
|
||||
@@ -13,6 +14,7 @@ from diart.sources import MicrophoneAudioSource
|
||||
from rx.core import Observer
|
||||
from typing import Tuple, Any, List
|
||||
from pyannote.core import Annotation
|
||||
import diart.models as m
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -78,40 +80,114 @@ class DiarizationObserver(Observer):
|
||||
|
||||
class WebSocketAudioSource(AudioSource):
|
||||
"""
|
||||
Custom AudioSource that blocks in read() until close() is called.
|
||||
Use push_audio() to inject PCM chunks.
|
||||
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
|
||||
"""
|
||||
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
|
||||
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
|
||||
super().__init__(uri, sample_rate)
|
||||
self.block_duration = block_duration
|
||||
self.block_size = int(np.rint(block_duration * sample_rate))
|
||||
self._queue = SimpleQueue()
|
||||
self._buffer = np.array([], dtype=np.float32)
|
||||
self._buffer_lock = threading.Lock()
|
||||
self._closed = False
|
||||
self._close_event = threading.Event()
|
||||
self._processing_thread = None
|
||||
self._last_chunk_time = time.time()
|
||||
|
||||
def read(self):
|
||||
"""Start processing buffered audio and emit fixed-size chunks."""
|
||||
self._processing_thread = threading.Thread(target=self._process_chunks)
|
||||
self._processing_thread.daemon = True
|
||||
self._processing_thread.start()
|
||||
|
||||
self._close_event.wait()
|
||||
if self._processing_thread:
|
||||
self._processing_thread.join(timeout=2.0)
|
||||
|
||||
def _process_chunks(self):
|
||||
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
|
||||
while not self._closed:
|
||||
try:
|
||||
audio_chunk = self._queue.get(timeout=0.1)
|
||||
|
||||
with self._buffer_lock:
|
||||
self._buffer = np.concatenate([self._buffer, audio_chunk])
|
||||
|
||||
while len(self._buffer) >= self.block_size:
|
||||
chunk = self._buffer[:self.block_size]
|
||||
self._buffer = self._buffer[self.block_size:]
|
||||
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - self._last_chunk_time
|
||||
if time_since_last < self.block_duration:
|
||||
time.sleep(self.block_duration - time_since_last)
|
||||
|
||||
chunk_reshaped = chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
self._last_chunk_time = time.time()
|
||||
|
||||
except Empty:
|
||||
with self._buffer_lock:
|
||||
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
|
||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||
padded_chunk[:len(self._buffer)] = self._buffer
|
||||
self._buffer = np.array([], dtype=np.float32)
|
||||
|
||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
self._last_chunk_time = time.time()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio processing thread: {e}")
|
||||
self.stream.on_error(e)
|
||||
break
|
||||
|
||||
with self._buffer_lock:
|
||||
if len(self._buffer) > 0:
|
||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||
padded_chunk[:len(self._buffer)] = self._buffer
|
||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
|
||||
self.stream.on_completed()
|
||||
|
||||
def close(self):
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self.stream.on_completed()
|
||||
self._close_event.set()
|
||||
|
||||
def push_audio(self, chunk: np.ndarray):
|
||||
"""Add audio chunk to the processing queue."""
|
||||
if not self._closed:
|
||||
new_audio = np.expand_dims(chunk, axis=0)
|
||||
logger.debug('Add new chunk with shape:', new_audio.shape)
|
||||
self.stream.on_next(new_audio)
|
||||
if chunk.ndim > 1:
|
||||
chunk = chunk.flatten()
|
||||
self._queue.put(chunk)
|
||||
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
|
||||
|
||||
|
||||
class DiartDiarization:
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||
|
||||
if config is None:
|
||||
config = SpeakerDiarizationConfig(
|
||||
segmentation=segmentation_model,
|
||||
embedding=embedding_model,
|
||||
)
|
||||
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
self.lag_diart = None
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource()
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
self.custom_source = None
|
||||
else:
|
||||
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
|
||||
self.custom_source = WebSocketAudioSource(
|
||||
uri="websocket_source",
|
||||
sample_rate=sample_rate,
|
||||
block_duration=block_duration
|
||||
)
|
||||
self.source = self.custom_source
|
||||
|
||||
self.inference = StreamingInference(
|
||||
@@ -138,16 +214,102 @@ class DiartDiarization:
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
Uses the segments collected by the observer.
|
||||
|
||||
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
|
||||
"""
|
||||
segments = self.observer.get_segments()
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
|
||||
logger.debug(f"Available segments: {len(segments)}")
|
||||
for i, seg in enumerate(segments[:5]): # Show first 5 segments
|
||||
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
|
||||
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
return end_attributed_speaker
|
||||
|
||||
if use_punctuation_split and len(tokens) > 1:
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
|
||||
print("Here are the tokens:",
|
||||
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
||||
|
||||
segment_map = []
|
||||
for segment in segments:
|
||||
speaker_num = extract_number(segment.speaker) + 1
|
||||
segment_map.append((segment.start, segment.end, speaker_num))
|
||||
segment_map.sort(key=lambda x: x[0])
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
current_token = tokens[i]
|
||||
|
||||
is_sentence_end = False
|
||||
if current_token.text and current_token.text.strip():
|
||||
text = current_token.text.strip()
|
||||
if text[-1] in punctuation_marks:
|
||||
is_sentence_end = True
|
||||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
||||
|
||||
if is_sentence_end and current_token.speaker != -1:
|
||||
punctuation_time = current_token.end
|
||||
current_speaker = current_token.speaker
|
||||
|
||||
j = i + 1
|
||||
next_sentence_tokens = []
|
||||
while j < len(tokens):
|
||||
next_token = tokens[j]
|
||||
next_sentence_tokens.append(j)
|
||||
|
||||
# Check if this token ends the next sentence
|
||||
if next_token.text and next_token.text.strip():
|
||||
if next_token.text.strip()[-1] in punctuation_marks:
|
||||
break
|
||||
j += 1
|
||||
|
||||
if next_sentence_tokens:
|
||||
speaker_times = {}
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
token = tokens[idx]
|
||||
# Find which segments overlap with this token
|
||||
for seg_start, seg_end, seg_speaker in segment_map:
|
||||
if not (seg_end <= token.start or seg_start >= token.end):
|
||||
# Calculate overlap duration
|
||||
overlap_start = max(seg_start, token.start)
|
||||
overlap_end = min(seg_end, token.end)
|
||||
overlap_duration = overlap_end - overlap_start
|
||||
|
||||
if seg_speaker not in speaker_times:
|
||||
speaker_times[seg_speaker] = 0
|
||||
speaker_times[seg_speaker] += overlap_duration
|
||||
|
||||
if speaker_times:
|
||||
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
||||
|
||||
if dominant_speaker != current_speaker:
|
||||
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
||||
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker != dominant_speaker:
|
||||
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
||||
tokens[idx].speaker = dominant_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
else:
|
||||
for idx in next_sentence_tokens:
|
||||
if tokens[idx].speaker == -1:
|
||||
tokens[idx].speaker = current_speaker
|
||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
||||
|
||||
i += 1
|
||||
|
||||
return end_attributed_speaker
|
||||
|
||||
253
whisperlivekit/parse_args.py
Normal file
253
whisperlivekit/parse_args.py
Normal file
@@ -0,0 +1,253 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="The host address to bind the server to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to bind the server to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-file",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="warmup_file",
|
||||
help="""
|
||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
If False, no warmup is performed.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence-validation",
|
||||
action="store_true",
|
||||
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable speaker diarization.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--punctuation-split",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--segmentation-model",
|
||||
type=str,
|
||||
default="pyannote/segmentation-3.0",
|
||||
help="Hugging Face model ID for pyannote.audio segmentation model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="pyannote/embedding",
|
||||
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-transcription",
|
||||
action="store_true",
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="tiny",
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Overriding the default model cache dir where models downloaded from the hub are saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lan",
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
help="Transcribe or translate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="faster-whisper",
|
||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||
help="Load only this backend for Whisper processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-vad",
|
||||
action="store_true",
|
||||
help="Disable VAD (voice activity detection).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--buffer_trimming",
|
||||
type=str,
|
||||
default="segment",
|
||||
choices=["sentence", "segment"],
|
||||
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer_trimming_sec",
|
||||
type=float,
|
||||
default=15,
|
||||
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--log-level",
|
||||
dest="log_level",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the log level",
|
||||
default="DEBUG",
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
|
||||
# SimulStreaming-specific arguments
|
||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--frame-threshold",
|
||||
type=int,
|
||||
default=25,
|
||||
dest="frame_threshold",
|
||||
help="Threshold for the attention-guided decoding. The AlignAtt policy will decode only until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for large-v3 model.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--beams",
|
||||
"-b",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="decoder_type",
|
||||
choices=["beam", "greedy"],
|
||||
help="Override automatic selection of beam or greedy decoder. If beams > 1 and greedy: invalid.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--audio-max-len",
|
||||
type=float,
|
||||
default=30.0,
|
||||
dest="audio_max_len",
|
||||
help="Max length of the audio buffer, in seconds.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--audio-min-len",
|
||||
type=float,
|
||||
default=0.0,
|
||||
dest="audio_min_len",
|
||||
help="Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--cif-ckpt-path",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="cif_ckpt_path",
|
||||
help="The file path to the Simul-Whisper's CIF model checkpoint that detects whether there is end of word at the end of the chunk. If not, the last decoded space-separated word is truncated because it is often wrong -- transcribing a word in the middle. The CIF model adapted for the Whisper model version should be used. Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . Note that there is no model for large-v3.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--never-fire",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="never_fire",
|
||||
help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. If False: if CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--init-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="init_prompt",
|
||||
help="Init prompt for the model. It should be in the target language.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--static-init-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="static_init_prompt",
|
||||
help="Do not scroll over this text. It can contain terminology that should be relevant over all document.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--max-context-tokens",
|
||||
type=int,
|
||||
default=None,
|
||||
dest="max_context_tokens",
|
||||
help="Max context tokens for the model. Default is 0.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="model_path",
|
||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
args.vad = not args.no_vad
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
return args
|
||||
27
whisperlivekit/simul_whisper/dual_license_simulstreaming.md
Normal file
27
whisperlivekit/simul_whisper/dual_license_simulstreaming.md
Normal file
@@ -0,0 +1,27 @@
|
||||
|
||||
|
||||
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
|
||||
|
||||
SimulStreaming is dual-licensed:
|
||||
|
||||
🔹 Non-Commercial Use
|
||||
|
||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
|
||||
obtain the code through the GitHub repository. This license is **free of charge**
|
||||
and comes with **no obligations** for non-commercial users.
|
||||
|
||||
🔸 Commercial Use
|
||||
|
||||
Understanding who uses SimulStreaming commercially helps us improve and
|
||||
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
|
||||
|
||||
We plan to make the commercial licenceses **affordable** to SMEs and individuals. We
|
||||
are considering to provide commercial licenses either for free or for symbolic
|
||||
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
|
||||
|
||||
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
|
||||
available.
|
||||
|
||||
✉️ Contact
|
||||
|
||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
||||
@@ -26,4 +26,7 @@ class Transcript(TimedText):
|
||||
|
||||
@dataclass
|
||||
class SpeakerSegment(TimedText):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
pass
|
||||
73
whisperlivekit/token_buffer.py
Normal file
73
whisperlivekit/token_buffer.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import sys
|
||||
class TokenBuffer:
|
||||
|
||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
||||
self.text = text
|
||||
self.prefix_token_ids = prefix_token_ids
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
|
||||
def as_token_ids(self, tokenizer=None):
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError("Tokenizer is not set.")
|
||||
return self.prefix_token_ids + tokenizer.encode(self.text)
|
||||
|
||||
def as_tensor(self, device=None):
|
||||
if device is None:
|
||||
device = self.device
|
||||
if device is None:
|
||||
raise ValueError("Device is not set.")
|
||||
tok_ids = self.as_token_ids()
|
||||
return torch.tensor(tok_ids,
|
||||
dtype=torch.long, device=device).unsqueeze(0)
|
||||
|
||||
def as_tensor_beam(self, beam, device=None):
|
||||
t = self.as_tensor(device=device)
|
||||
return t.repeat_interleave(beam, dim=0)
|
||||
|
||||
|
||||
def as_text(self):
|
||||
return self.text
|
||||
|
||||
@staticmethod
|
||||
def empty(*a, **kw):
|
||||
return TokenBuffer(*a,**kw)
|
||||
|
||||
@staticmethod
|
||||
def from_text(text, *a, **kw):
|
||||
return TokenBuffer(*a, text=text, **kw)
|
||||
|
||||
def is_empty(self):
|
||||
return self.text is None or self.text == ""
|
||||
|
||||
def trim_words(self, num=1, after=0):
|
||||
'''
|
||||
num: how many words to trim from the beginning
|
||||
after: how many characters to skip (length of the static prompt)
|
||||
'''
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
|
||||
ids = tokenizer.encode(self.text[after:])
|
||||
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
||||
print(words, file=sys.stderr)
|
||||
print(wids, file=sys.stderr)
|
||||
if not words:
|
||||
return 0
|
||||
self.text = self.text[:after] + "".join(words[num:])
|
||||
return sum(len(wi) for wi in wids[:num])
|
||||
|
||||
def append_token_ids(self, token_ids):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
self.text += self.tokenizer.decode(token_ids)
|
||||
|
||||
def as_split_word_tokens(self):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
ids = tokenizer.encode(self.text)
|
||||
return tokenizer.split_to_word_tokens(ids)
|
||||
0
whisperlivekit/web/__init__.py
Normal file
0
whisperlivekit/web/__init__.py
Normal file
@@ -38,7 +38,6 @@
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
/* Shape inside the button */
|
||||
.shape-container {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
@@ -56,6 +55,10 @@
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#recordButton:disabled .shape {
|
||||
background-color: #6e6d6d;
|
||||
}
|
||||
|
||||
#recordButton.recording .shape {
|
||||
border-radius: 5px;
|
||||
width: 25px;
|
||||
@@ -279,7 +282,7 @@
|
||||
</div>
|
||||
<div>
|
||||
<label for="websocketInput">WebSocket URL:</label>
|
||||
<input id="websocketInput" type="text" value="ws://localhost:8000/asr" />
|
||||
<input id="websocketInput" type="text" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -304,6 +307,8 @@
|
||||
let waveCanvas = document.getElementById("waveCanvas");
|
||||
let waveCtx = waveCanvas.getContext("2d");
|
||||
let animationFrame = null;
|
||||
let waitingForStop = false;
|
||||
let lastReceivedData = null;
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||
@@ -315,6 +320,13 @@
|
||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||
const timerElement = document.querySelector(".timer");
|
||||
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = window.location.port || "8000";
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
const defaultWebSocketUrl = `${protocol}://${host}:${port}/asr`;
|
||||
websocketInput.value = defaultWebSocketUrl;
|
||||
websocketUrl = defaultWebSocketUrl;
|
||||
|
||||
chunkSelector.addEventListener("change", () => {
|
||||
chunkDuration = parseInt(chunkSelector.value);
|
||||
});
|
||||
@@ -346,12 +358,31 @@
|
||||
|
||||
websocket.onclose = () => {
|
||||
if (userClosing) {
|
||||
statusText.textContent = "WebSocket closed by user.";
|
||||
if (waitingForStop) {
|
||||
statusText.textContent = "Processing finalized or connection closed.";
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0, 0, true // isFinalizing = true
|
||||
);
|
||||
}
|
||||
}
|
||||
// If ready_to_stop was received, statusText is already "Finished processing..."
|
||||
// and waitingForStop is false.
|
||||
} else {
|
||||
statusText.textContent =
|
||||
"Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
userClosing = false;
|
||||
isRecording = false;
|
||||
waitingForStop = false;
|
||||
userClosing = false;
|
||||
lastReceivedData = null;
|
||||
websocket = null;
|
||||
updateUI();
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
@@ -363,12 +394,41 @@
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
// Check for status messages
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
waitingForStop = false;
|
||||
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
0, // No more lag
|
||||
0, // No more lag
|
||||
true // isFinalizing = true
|
||||
);
|
||||
}
|
||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||
recordButton.disabled = false;
|
||||
|
||||
if (websocket) {
|
||||
websocket.close(); // will trigger onclose
|
||||
// websocket = null; // onclose handle setting websocket to null
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
lastReceivedData = data;
|
||||
|
||||
// Handle normal transcription updates
|
||||
const {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription"
|
||||
} = data;
|
||||
|
||||
renderLinesWithBuffer(
|
||||
@@ -376,13 +436,20 @@
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
status
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription) {
|
||||
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
|
||||
if (current_status === "no_audio_detected") {
|
||||
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||
return;
|
||||
}
|
||||
|
||||
const linesHtml = lines.map((item, idx) => {
|
||||
let timeInfo = "";
|
||||
if (item.beg !== undefined && item.end !== undefined) {
|
||||
@@ -392,30 +459,46 @@
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0) {
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker == -1) {
|
||||
speakerLabel = `<span id="speaker"><span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker !== -1) {
|
||||
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker !== -1 && item.speaker !== 0) {
|
||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
}
|
||||
|
||||
let textContent = item.text;
|
||||
if (idx === lines.length - 1) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`
|
||||
}
|
||||
if (idx === lines.length - 1 && buffer_diarization) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`
|
||||
textContent += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||
}
|
||||
if (idx === lines.length - 1) {
|
||||
textContent += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing) {
|
||||
if (remaining_time_transcription > 0) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
|
||||
}
|
||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer_diarization) {
|
||||
if (isFinalizing) {
|
||||
currentLineText += (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||
}
|
||||
}
|
||||
if (buffer_transcription) {
|
||||
if (isFinalizing) {
|
||||
currentLineText += (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") + buffer_transcription.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return textContent
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${textContent}</div></p>`
|
||||
: `<p>${speakerLabel}<br/></p>`;
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
: `<p>${speakerLabel}<br/></p>`;
|
||||
}).join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
@@ -494,8 +577,17 @@
|
||||
}
|
||||
}
|
||||
|
||||
function stopRecording() {
|
||||
async function stopRecording() {
|
||||
userClosing = true;
|
||||
waitingForStop = true;
|
||||
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
// Send empty audio buffer as stop signal
|
||||
const emptyBlob = new Blob([], { type: 'audio/webm' });
|
||||
websocket.send(emptyBlob);
|
||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
recorder.stop();
|
||||
recorder = null;
|
||||
@@ -531,38 +623,60 @@
|
||||
timerElement.textContent = "00:00";
|
||||
startTime = null;
|
||||
|
||||
|
||||
isRecording = false;
|
||||
|
||||
if (websocket) {
|
||||
websocket.close();
|
||||
websocket = null;
|
||||
}
|
||||
|
||||
updateUI();
|
||||
}
|
||||
|
||||
async function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
linesTranscriptDiv.innerHTML = "";
|
||||
if (waitingForStop) {
|
||||
console.log("Waiting for stop, early return");
|
||||
return; // Early return, UI is already updated
|
||||
}
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
await setupWebSocket();
|
||||
await startRecording();
|
||||
// If we have an active WebSocket that's still processing, just restart audio capture
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await startRecording();
|
||||
} else {
|
||||
// If no active WebSocket or it's closed, create new one
|
||||
await setupWebSocket();
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||
console.error(err);
|
||||
}
|
||||
} else {
|
||||
console.log("Stopping recording");
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
|
||||
function updateUI() {
|
||||
recordButton.classList.toggle("recording", isRecording);
|
||||
statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
|
||||
recordButton.disabled = waitingForStop;
|
||||
|
||||
if (waitingForStop) {
|
||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "Recording...";
|
||||
} else {
|
||||
if (statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
statusText.textContent !== "Processing finalized or connection closed.") {
|
||||
statusText.textContent = "Click to start transcription";
|
||||
}
|
||||
}
|
||||
if (!waitingForStop) {
|
||||
recordButton.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
recordButton.addEventListener("click", toggleRecording);
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
|
||||
13
whisperlivekit/web/web_interface.py
Normal file
13
whisperlivekit/web/web_interface.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
import importlib.resources as resources
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_web_interface_html():
|
||||
"""Loads the HTML for the web interface using importlib.resources."""
|
||||
try:
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading web interface HTML: {e}")
|
||||
return "<html><body><h1>Error loading interface</h1></body></html>"
|
||||
0
whisperlivekit/whisper_streaming_custom/__init__.py
Normal file
0
whisperlivekit/whisper_streaming_custom/__init__.py
Normal file
@@ -3,13 +3,29 @@ import logging
|
||||
import io
|
||||
import soundfile as sf
|
||||
import math
|
||||
import torch
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
torch = None
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
|
||||
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||
SIMULSTREAMING_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("SimulStreaming dependencies not available. SimulStreaming backend will not be available.")
|
||||
SIMULSTREAMING_AVAILABLE = False
|
||||
AlignAttConfig = None
|
||||
PaddedAlignAttWhisper = None
|
||||
DEC_PAD = None
|
||||
tokenizer = None
|
||||
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
@@ -102,8 +118,9 @@ class FasterWhisperASR(ASRBase):
|
||||
model_size_or_path = modelsize
|
||||
else:
|
||||
raise ValueError("Either modelsize or model_dir must be set")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "float32"
|
||||
device = "auto" # Allow CTranslate2 to decide available device
|
||||
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
|
||||
|
||||
|
||||
model = WhisperModel(
|
||||
model_size_or_path,
|
||||
@@ -249,8 +266,8 @@ class OpenaiApiASR(ASRBase):
|
||||
no_speech_segments = []
|
||||
if self.use_vad_opt:
|
||||
for segment in segments.segments:
|
||||
if segment["no_speech_prob"] > 0.8:
|
||||
no_speech_segments.append((segment.get("start"), segment.get("end")))
|
||||
if segment.no_speech_prob > 0.8:
|
||||
no_speech_segments.append((segment.start, segment.end))
|
||||
tokens = []
|
||||
for word in segments.words:
|
||||
start = word.start
|
||||
@@ -289,4 +306,181 @@ class OpenaiApiASR(ASRBase):
|
||||
self.use_vad_opt = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.task = "translate"
|
||||
self.task = "translate"
|
||||
|
||||
|
||||
class SimulStreamingASR(ASRBase):
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = " "
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
|
||||
if not SIMULSTREAMING_AVAILABLE:
|
||||
raise ImportError("""SimulStreaming dependencies are not available. Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]". If you are building from source, you should also copy the content of the simul_whisper directory from the SimulStreaming repository into whisperlivekit/simul_whisper.""")
|
||||
with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
|
||||
print("*"*80 + f.read() + "*"*80)
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
self.original_language = None if lan == "auto" else lan
|
||||
|
||||
self.model_path = kwargs.get('model_path', './large-v3.pt')
|
||||
self.frame_threshold = kwargs.get('frame_threshold', 25)
|
||||
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
|
||||
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
|
||||
self.segment_length = kwargs.get('segment_length', 0.5)
|
||||
self.beams = kwargs.get('beams', 1)
|
||||
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
|
||||
self.task = kwargs.get('task', 'transcribe')
|
||||
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
|
||||
self.never_fire = kwargs.get('never_fire', False)
|
||||
self.init_prompt = kwargs.get('init_prompt', None)
|
||||
self.static_init_prompt = kwargs.get('static_init_prompt', None)
|
||||
self.max_context_tokens = kwargs.get('max_context_tokens', None)
|
||||
|
||||
if model_dir is not None:
|
||||
self.model_path = model_dir
|
||||
elif modelsize is not None: #For the moment the .en.pt models do not work!
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
'small': './small.pt',
|
||||
'medium': './medium.pt',
|
||||
'medium.en': './medium.en.pt',
|
||||
'large-v1': './large-v1.pt',
|
||||
'base.en': './base.en.pt',
|
||||
'small.en': './small.en.pt',
|
||||
'tiny.en': './tiny.en.pt',
|
||||
'large-v2': './large-v2.pt',
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
|
||||
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
|
||||
# Set up tokenizer for translation if needed
|
||||
if self.task == "translate":
|
||||
self.set_translate_task()
|
||||
|
||||
def load_model(self, modelsize, cache_dir, model_dir):
|
||||
try:
|
||||
cfg = AlignAttConfig(
|
||||
model_path=self.model_path,
|
||||
segment_length=self.segment_length,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.original_language,
|
||||
audio_max_len=self.audio_max_len,
|
||||
audio_min_len=self.audio_min_len,
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.task,
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
static_init_prompt=self.static_init_prompt,
|
||||
)
|
||||
|
||||
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
|
||||
model = PaddedAlignAttWhisper(cfg)
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load SimulStreaming model: {e}")
|
||||
raise
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
"""Transcribe audio using SimulStreaming."""
|
||||
try:
|
||||
if isinstance(audio, np.ndarray):
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
else:
|
||||
audio_tensor = audio
|
||||
|
||||
prompt = init_prompt if init_prompt else (self.init_prompt or "")
|
||||
|
||||
result = self.model.infer(audio_tensor, init_prompt=prompt)
|
||||
|
||||
if torch.is_tensor(result):
|
||||
result = result[result < DEC_PAD]
|
||||
|
||||
logger.debug(f"SimulStreaming transcription result: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SimulStreaming transcription failed: {e}")
|
||||
raise
|
||||
|
||||
def ts_words(self, result) -> List[ASRToken]:
|
||||
"""Convert SimulStreaming result to ASRToken list."""
|
||||
tokens = []
|
||||
|
||||
try:
|
||||
if torch.is_tensor(result):
|
||||
text = self.model.tokenizer.decode(result.cpu().numpy())
|
||||
else:
|
||||
text = str(result)
|
||||
|
||||
if not text or len(text.strip()) == 0:
|
||||
return tokens
|
||||
|
||||
# We dont have word-level timestamps here. 1rst approach, should be improved later.
|
||||
words = text.strip().split()
|
||||
if not words:
|
||||
return tokens
|
||||
|
||||
duration_per_word = 0.1 # this will be modified based on actual audio duration
|
||||
#with the SimulStreamingOnlineProcessor
|
||||
|
||||
for i, word in enumerate(words):
|
||||
start_time = i * duration_per_word
|
||||
end_time = (i + 1) * duration_per_word
|
||||
|
||||
token = ASRToken(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=word,
|
||||
probability=1.0
|
||||
)
|
||||
tokens.append(token)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting SimulStreaming result to tokens: {e}")
|
||||
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, result) -> List[float]:
|
||||
"""Get segment end timestamps."""
|
||||
if torch.is_tensor(result):
|
||||
num_tokens = len(result)
|
||||
return [num_tokens * 0.1] # rough estimate
|
||||
return [1.0]
|
||||
|
||||
def use_vad(self):
|
||||
"""Enable VAD - SimulStreaming has different VAD handling."""
|
||||
logger.info("VAD requested for SimulStreaming - handled internally by the model")
|
||||
pass
|
||||
|
||||
def set_translate_task(self):
|
||||
"""Set up translation task."""
|
||||
try:
|
||||
self.model.tokenizer = tokenizer.get_tokenizer(
|
||||
multilingual=True,
|
||||
language=self.model.cfg.language,
|
||||
num_languages=self.model.model.num_languages,
|
||||
task="translate"
|
||||
)
|
||||
logger.info("SimulStreaming configured for translation task")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure SimulStreaming for translation: {e}")
|
||||
raise
|
||||
|
||||
def warmup(self, audio, init_prompt=""):
|
||||
"""Warmup the SimulStreaming model."""
|
||||
try:
|
||||
if isinstance(audio, np.ndarray):
|
||||
audio = torch.from_numpy(audio).float()
|
||||
self.model.infer(audio, True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
logger.info("SimulStreaming model warmed up successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"SimulStreaming warmup failed: {e}")
|
||||
|
||||
@@ -6,6 +6,17 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# simulStreaming imports - we check if the files are here
|
||||
try:
|
||||
import torch
|
||||
from simul_whisper.config import AlignAttConfig
|
||||
SIMULSTREAMING_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("SimulStreaming dependencies not available for online processor.")
|
||||
SIMULSTREAMING_AVAILABLE = False
|
||||
OnlineProcessorInterface = None
|
||||
torch = None
|
||||
|
||||
|
||||
class HypothesisBuffer:
|
||||
"""
|
||||
@@ -144,7 +155,11 @@ class OnlineASRProcessor:
|
||||
self.transcript_buffer.last_committed_time = self.buffer_time_offset
|
||||
self.committed: List[ASRToken] = []
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray):
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
"""Returns the absolute end time of the current audio_buffer."""
|
||||
return self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
||||
"""Append an audio chunk (a numpy array) to the current audio buffer."""
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
@@ -179,18 +194,19 @@ class OnlineASRProcessor:
|
||||
return self.concatenate_tokens(self.transcript_buffer.buffer)
|
||||
|
||||
|
||||
def process_iter(self) -> Transcript:
|
||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Processes the current audio buffer.
|
||||
|
||||
Returns a Transcript object representing the committed transcript.
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
current_audio_processed_upto = self.get_audio_buffer_end_time()
|
||||
prompt_text, _ = self.prompt()
|
||||
logger.debug(
|
||||
f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
|
||||
)
|
||||
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
|
||||
tokens = self.asr.ts_words(res) # Expecting List[ASRToken]
|
||||
tokens = self.asr.ts_words(res)
|
||||
self.transcript_buffer.insert(tokens, self.buffer_time_offset)
|
||||
committed_tokens = self.transcript_buffer.flush()
|
||||
self.committed.extend(committed_tokens)
|
||||
@@ -210,37 +226,60 @@ class OnlineASRProcessor:
|
||||
logger.debug(
|
||||
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
||||
)
|
||||
return committed_tokens
|
||||
return committed_tokens, current_audio_processed_upto
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
"""
|
||||
If the committed tokens form at least two sentences, chunk the audio
|
||||
buffer at the end time of the penultimate sentence.
|
||||
Also ensures chunking happens if audio buffer exceeds a time limit.
|
||||
"""
|
||||
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
||||
if not self.committed:
|
||||
if buffer_duration > self.buffer_trimming_sec:
|
||||
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
||||
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
||||
self.chunk_at(chunk_time)
|
||||
return
|
||||
|
||||
logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
|
||||
sentences = self.words_to_sentences(self.committed)
|
||||
for sentence in sentences:
|
||||
logger.debug(f"\tSentence: {sentence.text}")
|
||||
if len(sentences) < 2:
|
||||
return
|
||||
# Keep the last two sentences.
|
||||
while len(sentences) > 2:
|
||||
sentences.pop(0)
|
||||
chunk_time = sentences[-2].end
|
||||
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
|
||||
self.chunk_at(chunk_time)
|
||||
|
||||
chunk_done = False
|
||||
if len(sentences) >= 2:
|
||||
while len(sentences) > 2:
|
||||
sentences.pop(0)
|
||||
chunk_time = sentences[-2].end
|
||||
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
|
||||
self.chunk_at(chunk_time)
|
||||
chunk_done = True
|
||||
|
||||
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
||||
last_committed_time = self.committed[-1].end
|
||||
logger.debug(f"--- Not enough sentences, chunking at last committed time {last_committed_time:.2f}")
|
||||
self.chunk_at(last_committed_time)
|
||||
|
||||
def chunk_completed_segment(self, res):
|
||||
"""
|
||||
Chunk the audio buffer based on segment-end timestamps reported by the ASR.
|
||||
Also ensures chunking happens if audio buffer exceeds a time limit.
|
||||
"""
|
||||
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
||||
if not self.committed:
|
||||
if buffer_duration > self.buffer_trimming_sec:
|
||||
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
||||
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
||||
self.chunk_at(chunk_time)
|
||||
return
|
||||
|
||||
logger.debug("Processing committed tokens for segmenting")
|
||||
ends = self.asr.segments_end_ts(res)
|
||||
last_committed_time = self.committed[-1].end
|
||||
last_committed_time = self.committed[-1].end
|
||||
chunk_done = False
|
||||
if len(ends) > 1:
|
||||
logger.debug("Multiple segments available for chunking")
|
||||
e = ends[-2] + self.buffer_time_offset
|
||||
while len(ends) > 2 and e > last_committed_time:
|
||||
ends.pop(-1)
|
||||
@@ -248,11 +287,18 @@ class OnlineASRProcessor:
|
||||
if e <= last_committed_time:
|
||||
logger.debug(f"--- Segment chunked at {e:.2f}")
|
||||
self.chunk_at(e)
|
||||
chunk_done = True
|
||||
else:
|
||||
logger.debug("--- Last segment not within committed area")
|
||||
else:
|
||||
logger.debug("--- Not enough segments to chunk")
|
||||
|
||||
|
||||
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
||||
logger.debug(f"--- Buffer too large, chunking at last committed time {last_committed_time:.2f}")
|
||||
self.chunk_at(last_committed_time)
|
||||
|
||||
logger.debug("Segment chunking complete")
|
||||
|
||||
def chunk_at(self, time: float):
|
||||
"""
|
||||
Trim both the hypothesis and audio buffer at the given time.
|
||||
@@ -313,15 +359,17 @@ class OnlineASRProcessor:
|
||||
)
|
||||
sentences.append(sentence)
|
||||
return sentences
|
||||
def finish(self) -> Transcript:
|
||||
|
||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Flush the remaining transcript when processing ends.
|
||||
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
|
||||
"""
|
||||
remaining_tokens = self.transcript_buffer.buffer
|
||||
final_transcript = self.concatenate_tokens(remaining_tokens)
|
||||
logger.debug(f"Final non-committed transcript: {final_transcript}")
|
||||
self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
|
||||
return final_transcript
|
||||
logger.debug(f"Final non-committed tokens: {remaining_tokens}")
|
||||
final_processed_upto = self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
|
||||
self.buffer_time_offset = final_processed_upto
|
||||
return remaining_tokens, final_processed_upto
|
||||
|
||||
def concatenate_tokens(
|
||||
self,
|
||||
@@ -354,36 +402,44 @@ class VACOnlineASRProcessor:
|
||||
def __init__(self, online_chunk_size: float, *args, **kwargs):
|
||||
self.online_chunk_size = online_chunk_size
|
||||
self.online = OnlineASRProcessor(*args, **kwargs)
|
||||
|
||||
self.asr = self.online.asr
|
||||
|
||||
# Load a VAD model (e.g. Silero VAD)
|
||||
import torch
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
from silero_vad_iterator import FixedVADIterator
|
||||
from .silero_vad_iterator import FixedVADIterator
|
||||
|
||||
self.vac = FixedVADIterator(model)
|
||||
self.logfile = self.online.logfile
|
||||
self.last_input_audio_stream_end_time: float = 0.0
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
self.online.init()
|
||||
self.vac.reset_states()
|
||||
self.current_online_chunk_buffer_size = 0
|
||||
self.last_input_audio_stream_end_time = self.online.buffer_time_offset
|
||||
self.is_currently_final = False
|
||||
self.status: Optional[str] = None # "voice" or "nonvoice"
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.buffer_offset = 0 # in frames
|
||||
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
"""Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
|
||||
return self.online.get_audio_buffer_end_time()
|
||||
|
||||
def clear_buffer(self):
|
||||
self.buffer_offset += len(self.audio_buffer)
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray):
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
||||
"""
|
||||
Process an incoming small audio chunk:
|
||||
- run VAD on the chunk,
|
||||
- decide whether to send the audio to the online ASR processor immediately,
|
||||
- and/or to mark the current utterance as finished.
|
||||
"""
|
||||
self.last_input_audio_stream_end_time = audio_stream_end_time
|
||||
res = self.vac(audio)
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
@@ -425,10 +481,11 @@ class VACOnlineASRProcessor:
|
||||
self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
|
||||
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
|
||||
|
||||
def process_iter(self) -> Transcript:
|
||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Depending on the VAD status and the amount of accumulated audio,
|
||||
process the current audio chunk.
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
if self.is_currently_final:
|
||||
return self.finish()
|
||||
@@ -437,17 +494,224 @@ class VACOnlineASRProcessor:
|
||||
return self.online.process_iter()
|
||||
else:
|
||||
logger.debug("No online update, only VAD")
|
||||
return Transcript(None, None, "")
|
||||
return [], self.last_input_audio_stream_end_time
|
||||
|
||||
def finish(self) -> Transcript:
|
||||
"""Finish processing by flushing any remaining text."""
|
||||
result = self.online.finish()
|
||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Finish processing by flushing any remaining text.
|
||||
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
|
||||
"""
|
||||
result_tokens, processed_upto = self.online.finish()
|
||||
self.current_online_chunk_buffer_size = 0
|
||||
self.is_currently_final = False
|
||||
return result
|
||||
return result_tokens, processed_upto
|
||||
|
||||
def get_buffer(self):
|
||||
"""
|
||||
Get the unvalidated buffer in string format.
|
||||
"""
|
||||
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
|
||||
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
|
||||
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
tokenize_method: Optional[callable] = None,
|
||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
||||
confidence_validation = False,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
if not SIMULSTREAMING_AVAILABLE:
|
||||
raise ImportError("SimulStreaming dependencies are not available.")
|
||||
|
||||
self.asr = asr
|
||||
self.tokenize = tokenize_method
|
||||
self.logfile = logfile
|
||||
self.confidence_validation = confidence_validation
|
||||
self.init()
|
||||
|
||||
# buffer does not work yet
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
|
||||
def init(self, offset: Optional[float] = None):
|
||||
"""Initialize or reset the processing state."""
|
||||
self.audio_chunks = []
|
||||
self.offset = offset if offset is not None else 0.0
|
||||
self.is_last = False
|
||||
self.beg = self.offset
|
||||
self.end = self.offset
|
||||
self.cumulative_audio_duration = 0.0
|
||||
self.last_audio_stream_end_time = self.offset
|
||||
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.buffer_content = ""
|
||||
self.processed_audio_duration = 0.0
|
||||
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
"""Returns the absolute end time of the current audio buffer."""
|
||||
return self.end
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||
if torch is None:
|
||||
raise ImportError("PyTorch is required for SimulStreaming but not available")
|
||||
|
||||
# Convert numpy array to torch tensor
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.audio_chunks.append(audio_tensor)
|
||||
|
||||
# Update timing
|
||||
chunk_duration = len(audio) / self.SAMPLING_RATE
|
||||
self.cumulative_audio_duration += chunk_duration
|
||||
|
||||
if audio_stream_end_time is not None:
|
||||
self.last_audio_stream_end_time = audio_stream_end_time
|
||||
self.end = audio_stream_end_time
|
||||
else:
|
||||
self.end = self.offset + self.cumulative_audio_duration
|
||||
|
||||
def prompt(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns a tuple: (prompt, context).
|
||||
SimulStreaming handles prompting internally, so we return empty strings.
|
||||
"""
|
||||
return "", ""
|
||||
|
||||
def get_buffer(self):
|
||||
"""
|
||||
Get the unvalidated buffer content.
|
||||
"""
|
||||
buffer_end = self.end if hasattr(self, 'end') else None
|
||||
return Transcript(
|
||||
start=None,
|
||||
end=buffer_end,
|
||||
text=self.buffer_content,
|
||||
probability=None
|
||||
)
|
||||
|
||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Process accumulated audio chunks using SimulStreaming.
|
||||
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
if not self.audio_chunks:
|
||||
return [], self.end
|
||||
|
||||
try:
|
||||
# concatenate all audio chunks
|
||||
if len(self.audio_chunks) == 1:
|
||||
audio = self.audio_chunks[0]
|
||||
else:
|
||||
audio = torch.cat(self.audio_chunks, dim=0)
|
||||
|
||||
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
|
||||
self.processed_audio_duration += audio_duration
|
||||
|
||||
self.audio_chunks = []
|
||||
|
||||
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
|
||||
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
|
||||
|
||||
result = self.asr.model.infer(audio, is_last=self.is_last)
|
||||
|
||||
if torch.is_tensor(result):
|
||||
# we filter out padding tokens as it s done in simul whisper
|
||||
from simul_whisper.simul_whisper import DEC_PAD
|
||||
result = result[result < DEC_PAD]
|
||||
|
||||
# C/P from simul_whisper.simul_whisper.py
|
||||
if len(result) > 0:
|
||||
decoded_text = self.asr.model.tokenizer.decode(result.cpu().numpy())
|
||||
logger.debug(f"SimulStreaming decoded: {decoded_text}")
|
||||
|
||||
if decoded_text.strip():
|
||||
words = decoded_text.strip().split()
|
||||
new_tokens = []
|
||||
|
||||
num_words = len(words)
|
||||
if num_words > 0:
|
||||
# distribute words evenly across the processed audio duration
|
||||
# we NEED that for when we use diarization. Even if that s not perfect
|
||||
start_time = self.end - audio_duration
|
||||
time_per_word = audio_duration / num_words if num_words > 1 else audio_duration
|
||||
|
||||
for i, word in enumerate(words):
|
||||
token_start = start_time + (i * time_per_word)
|
||||
token_end = start_time + ((i + 1) * time_per_word)
|
||||
|
||||
token_end = min(token_end, self.end)
|
||||
|
||||
token = ASRToken(
|
||||
start=token_start,
|
||||
end=token_end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
)
|
||||
new_tokens.append(token)
|
||||
|
||||
self.beg = self.end
|
||||
|
||||
self.committed.extend(new_tokens)
|
||||
self.last_result_tokens = new_tokens
|
||||
|
||||
logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s")
|
||||
return new_tokens, self.end
|
||||
|
||||
return [], self.end
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SimulStreaming processing error: {e}")
|
||||
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
|
||||
return [], self.end
|
||||
|
||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||
logger.debug("SimulStreaming finish() called")
|
||||
self.is_last = True
|
||||
final_tokens, final_time = self.process_iter()
|
||||
self.is_last = False
|
||||
return final_tokens, final_time
|
||||
|
||||
def concatenate_tokens(
|
||||
self,
|
||||
tokens: List[ASRToken],
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> Transcript:
|
||||
"""Concatenate tokens into a Transcript object."""
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
|
||||
def chunk_at(self, time: float):
|
||||
"""
|
||||
useless but kept for compatibility
|
||||
"""
|
||||
logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
|
||||
pass
|
||||
|
||||
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
|
||||
"""
|
||||
Create simple sentences.
|
||||
"""
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
full_text = " ".join(token.text for token in tokens)
|
||||
sentence = Sentence(
|
||||
start=tokens[0].start,
|
||||
end=tokens[-1].end,
|
||||
text=full_text
|
||||
)
|
||||
return [sentence]
|
||||
|
||||
@@ -5,8 +5,8 @@ import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
|
||||
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE
|
||||
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,6 +69,37 @@ def backend_factory(args):
|
||||
if backend == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=args.lan)
|
||||
elif backend == "simulstreaming":
|
||||
logger.debug("Using SimulStreaming backend.")
|
||||
if not SIMULSTREAMING_AVAILABLE:
|
||||
raise ImportError(
|
||||
"SimulStreaming backend is not available. Please install SimulStreaming dependencies. "
|
||||
"See the documentation for installation instructions."
|
||||
)
|
||||
|
||||
simulstreaming_kwargs = {}
|
||||
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
|
||||
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
|
||||
'max_context_tokens', 'model_path']:
|
||||
if hasattr(args, attr):
|
||||
simulstreaming_kwargs[attr] = getattr(args, attr)
|
||||
|
||||
# Add segment_length from min_chunk_size
|
||||
simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
|
||||
simulstreaming_kwargs['task'] = args.task
|
||||
|
||||
size = args.model
|
||||
t = time.time()
|
||||
logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
|
||||
asr = SimulStreamingASR(
|
||||
modelsize=size,
|
||||
lan=args.lan,
|
||||
cache_dir=getattr(args, 'model_cache_dir', None),
|
||||
model_dir=getattr(args, 'model_dir', None),
|
||||
**simulstreaming_kwargs
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
else:
|
||||
if backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
@@ -84,8 +115,8 @@ def backend_factory(args):
|
||||
asr = asr_cls(
|
||||
modelsize=size,
|
||||
lan=args.lan,
|
||||
cache_dir=args.model_cache_dir,
|
||||
model_dir=args.model_dir,
|
||||
cache_dir=getattr(args, 'model_cache_dir', None),
|
||||
model_dir=getattr(args, 'model_dir', None),
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
@@ -97,21 +128,33 @@ def backend_factory(args):
|
||||
|
||||
language = args.lan
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
if backend != "simulstreaming":
|
||||
asr.set_translate_task()
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = language # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if args.buffer_trimming == "sentence":
|
||||
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
return asr, tokenizer
|
||||
|
||||
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||
if args.vac:
|
||||
if args.backend == "simulstreaming":
|
||||
if not SIMULSTREAMING_ONLINE_AVAILABLE:
|
||||
raise ImportError("SimulStreaming online processor is not available.")
|
||||
|
||||
logger.debug("Creating SimulStreaming online processor")
|
||||
online = SimulStreamingOnlineProcessor(
|
||||
asr,
|
||||
tokenizer,
|
||||
logfile=logfile,
|
||||
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
||||
confidence_validation=args.confidence_validation
|
||||
)
|
||||
elif args.vac:
|
||||
online = VACOnlineASRProcessor(
|
||||
args.min_chunk_size,
|
||||
asr,
|
||||
@@ -145,6 +188,7 @@ def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))
|
||||
|
||||
if warmup_file is None:
|
||||
# Download JFK sample if not already present
|
||||
@@ -179,16 +223,23 @@ def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
|
||||
return False
|
||||
|
||||
print(f"Warmping up Whisper with {warmup_file}")
|
||||
print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
|
||||
try:
|
||||
import librosa
|
||||
audio, sr = librosa.load(warmup_file, sr=16000)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load audio file: {e}")
|
||||
return False
|
||||
|
||||
# Process the audio
|
||||
asr.transcribe(audio)
|
||||
|
||||
logger.info("Whisper is warmed up")
|
||||
|
||||
|
||||
try:
|
||||
if is_simulstreaming:
|
||||
asr.warmup(audio)
|
||||
else:
|
||||
asr.transcribe(audio)
|
||||
|
||||
logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Warmup failed: {e}")
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user