34 Commits
0.2.6 ... 0.2.8

Author SHA1 Message Date
Quentin Fuxa
3bd2122eb4 0.2.8 : only the decoder of whisper is loaded in memory when a different encoder is used 2025-09-02 21:12:25 +02:00
Quentin Fuxa
50b0527858 update architecture 2025-09-01 21:24:12 +02:00
Quentin Fuxa
b044fcdec2 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-09-01 14:55:19 +02:00
Quentin Fuxa
b0508fcf2c mlx/fasterWhisper encoders are loaded once and shared in simulstreaming 2025-09-01 14:55:11 +02:00
Quentin Fuxa
ce89b0aebc Merge pull request #177 from komiyamma/translate-readme-to-japanese
Translate README.md to Japanese
2025-09-01 13:54:50 +02:00
Quentin Fuxa
d5008ed828 mlx/fasterWhisper encoders are loaded once and shared in simulstreaming 2025-09-01 12:33:19 +02:00
Quentin Fuxa
d467716e26 add microphone picker 2025-08-31 10:12:52 +02:00
Quentin Fuxa
199e21b3ef faster-whisper as an optional encoder alternative for simulstreaming 2025-08-30 23:50:16 +02:00
Quentin Fuxa
1d926f2e67 mlx-whisper used as simulstreaming encoder: improve speed for macos systems 2025-08-30 22:19:11 +02:00
Quentin Fuxa
4a71a391b8 get_web_interface_html to get_inline_ui_html for embedded web interface HTML 2025-08-30 13:44:06 +02:00
google-labs-jules[bot]
d3ed4e46e2 Translate README.md to Japanese
Create a Japanese version of the README.md file named ReadmeJP.md.
This makes the project more accessible to Japanese-speaking users.
2025-08-30 04:16:18 +00:00
Quentin Fuxa
057a1026d7 Merge branch 'main' of https://github.com/QuentinFuxa/whisper_streaming_web 2025-08-29 22:01:04 +02:00
Quentin Fuxa
1ba171a58d add embedded web interface HTML (single-file version with inline CSS/JS/SVG)
### Added
- `get_inline_ui_html()`: generates a self-contained version of the web interface, with CSS, JS, and SVG assets inlined directly into the HTML. useful for environments where serving static files is inconvenient or when a single-call UI delivery is preferred.

(cherry picked from commit aa44a92a67)
2025-08-29 22:00:59 +02:00
Quentin Fuxa
1adac67155 explanations about model persistency in containers 2025-08-29 21:27:08 +02:00
Quentin Fuxa
42be1a3773 Merge pull request #173 from CoderRahul9904/chore/docker/pytorch-timeout-retries
fix: increase pip timeout & retries for torch wheel install
2025-08-29 21:22:30 +02:00
Rahul Mourya
0a49fafa0d Update Dockerfile
fix(docker): increase pip timeout/retries for PyTorch wheel installs
2025-08-30 00:23:59 +05:30
Quentin Fuxa
4a5d5e1f3b raise Exception when language == auto and task == translation 2025-08-29 17:44:46 +02:00
Quentin Fuxa
583a2ec2e4 highlight Sortformer optional installation 2025-08-27 21:02:25 +02:00
Quentin Fuxa
19765e89e9 remove triton <3 condition 2025-08-27 20:44:39 +02:00
Quentin Fuxa
9895bc83bf auto detection of language for warmup if not indicated 2025-08-27 20:37:48 +02:00
Quentin Fuxa
ab98c31f16 trim will happen before audio processor 2025-08-27 18:17:11 +02:00
Quentin Fuxa
f9c9c4188a optional dependencies removed, ask to direct alternative package installations 2025-08-27 18:15:32 +02:00
Quentin Fuxa
c21d2302e7 to 0.2.7 2024-08-24 19:28:00 +02:00
Quentin Fuxa
4ed62e181d when silences are detected, speaker correction is no more applied 2024-08-24 19:24:00 +02:00
Quentin Fuxa
52a755a08c indications on how to choose a model 2024-08-24 19:22:00 +02:00
Quentin Fuxa
9a8d3cbd90 improve diarization + silence handling 2024-08-24 19:20:00 +02:00
Quentin Fuxa
b101ce06bd several users share the same sortformer model instance 2024-08-24 19:18:00 +02:00
Quentin Fuxa
c83fd179a8 improves phase shift correction between transcription and diarization 2024-08-24 19:15:00 +02:00
Quentin Fuxa
5258305745 default diarization backend in now sortformer 2025-08-24 18:32:01 +02:00
Quentin Fuxa
ce781831ee punctuation is checked in audio-processor's result formatter 2025-08-24 18:32:01 +02:00
Quentin Fuxa
58297daf6d sortformer diar implementation v0.3 2025-08-24 18:32:01 +02:00
Quentin Fuxa
3393a08f7e sortformer diar implementation v0.2 2025-08-24 18:32:01 +02:00
Quentin Fuxa
5b2ddeccdb correct pip installation error in image build 2025-08-22 15:37:46 +02:00
Quentin Fuxa
26cc1072dd new dockerfile for cpu only. update dockerfile from cuda 12.8 to 12.9 2025-08-22 11:04:35 +02:00
29 changed files with 1895 additions and 631 deletions

70
DEV_NOTES.md Normal file
View File

@@ -0,0 +1,70 @@
# 1. Simulstreaming: Decouple the encoder for faster inference
Simulstreaming encoder time (whisperlivekit/simul_whisper/simul_whisper.py l. 397) experimentations :
On macOS Apple Silicon M4 :
| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
Memory saved by only loading encoder for optimized framework:
For tiny.en, mlx whisper:
Sizes MLX whisper:
Decoder weights: 59110771 bytes
Encoder weights: 15268874 bytes
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
## Problem Statement
- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers
#
### Initial Setup
For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.
Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
### Algorithm
```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```
- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b
3. **Attribution Logic**
```
AS_0 ← A
AS_1 ← B
IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
# Likely that DS_a_0 = DS_a_1 (same speaker)
AS_1 ← A
AS_2 ← B
ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
AS_2 ← A
ELSE:
AS_2 ← B
to finish
```

View File

@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04 FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
@@ -9,48 +9,50 @@ ARG EXTRAS
ARG HF_PRECACHE_DIR ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE ARG HF_TKN_FILE
# Install system dependencies
#RUN apt-get update && \
# apt-get install -y ffmpeg git && \
# apt-get clean && \
# rm -rf /var/lib/apt/lists/*
# 2) Install system dependencies + Python + pip
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
python3 \ python3 \
python3-pip \ python3-pip \
python3-venv \
ffmpeg \ ffmpeg \
git \ git \
build-essential \ build-essential \
python3-dev && \ python3-dev \
ca-certificates && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchaudio \
|| (echo "Initial install failed — retrying with extended timeout..." && \
pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
--index-url https://download.pytorch.org/whl/cu129 \
torch torchvision torchaudio)
COPY . . COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies # Install WhisperLiveKit directly, allowing for optional dependencies
# Note: For gates models, need to add your HF toke. See README.md
# for more details.
RUN if [ -n "$EXTRAS" ]; then \ RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \ echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir .[$EXTRAS]; \ pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \ else \
echo "Installing base package only"; \ echo "Installing base package only"; \
pip install --no-cache-dir .; \ pip install --no-cache-dir whisperlivekit; \
fi fi
# Enable in-container caching for Hugging Face models by: # In-container caching for Hugging Face models by:
# Note: If running multiple containers, better to map a shared
# bucket.
#
# A) Make the cache directory persistent via an anonymous volume. # A) Make the cache directory persistent via an anonymous volume.
# Note: This only persists for a single, named container. This is # Note: This only persists for a single, named container. This is
# only for convenience at de/test stage. # only for convenience at de/test stage.
# For prod, it is better to use a named volume via host mount/k8s. # For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"] VOLUME ["/root/.cache/huggingface/hub"]
# or # or
# B) Conditionally copy a local pre-cache from the build context to the # B) Conditionally copy a local pre-cache from the build context to the
# container's cache via the HF_PRECACHE_DIR build-arg. # container's cache via the HF_PRECACHE_DIR build-arg.
@@ -65,8 +67,7 @@ RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "No local Hugging Face cache specified, skipping copy"; \ echo "No local Hugging Face cache specified, skipping copy"; \
fi fi
# Conditionally copy a Hugging Face token if provided # Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \ RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \ echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \ mkdir -p /root/.cache/huggingface && \
@@ -74,11 +75,9 @@ RUN if [ -n "$HF_TKN_FILE" ]; then \
else \ else \
echo "No Hugging Face token file specified, skipping token setup"; \ echo "No Hugging Face token file specified, skipping token setup"; \
fi fi
# Expose port for the transcription server
EXPOSE 8000 EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"] ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args CMD ["--model", "medium"]
CMD ["--model", "base"]

61
Dockerfile.cpu Normal file
View File

@@ -0,0 +1,61 @@
FROM python:3.13-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg \
git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]
# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]

117
README.md
View File

@@ -8,8 +8,8 @@
<p align="center"> <p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a> <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a> <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p> </p>
@@ -66,41 +66,31 @@ pip install whisperlivekit
| Optional | `pip install` | | Optional | `pip install` |
|-----------|-------------| |-----------|-------------|
| Speaker diarization | `whisperlivekit[diarization]` | | **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Original Whisper backend | `whisperlivekit[whisper]` | | **Apple Silicon optimized backend** | `mlx-whisper` |
| Improved timestamps backend | `whisperlivekit[whisper-timestamped]` | | *[Not recommanded]* Speaker diarization with Diart | `diart` |
| Apple Silicon optimization backend | `whisperlivekit[mlx-whisper]` | | *[Not recommanded]* Original Whisper backend | `whisper` |
| OpenAI API backend | `whisperlivekit[openai]` | | *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them. See **Parameters & Configuration** below on how to use them.
> **Pyannote Models Setup** For diarization, you need access to pyannote.audio models:
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
>4. Login with HuggingFace:
> ```bash
> huggingface-cli login
> ```
## 💻 Usage Examples
#### Command-line Interface ### Usage Examples
Start the transcription server with various options: **Command-line Interface**: Start the transcription server with various options:
```bash ```bash
# SimulStreaming backend for ultra-low latency # Use better model than default (small)
whisperlivekit-server --backend simulstreaming --model large-v3 whisperlivekit-server --model large-v3
# Advanced configuration with diarization # Advanced configuration with diarization and language
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
``` ```
#### Python API Integration (Backend) **Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
```python ```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
@@ -138,17 +128,26 @@ async def websocket_endpoint(websocket: WebSocket):
await audio_processor.process_audio(message) await audio_processor.process_audio(message)
``` ```
#### Frontend Implementation **Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`
The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
### ⚙️ Parameters & Configuration ## Parameters & Configuration
An important list of parameters can be changed. But what *should* you change?
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--task translate`, to translate in english
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
The rest I don't recommend. But below are your options.
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--model` | Whisper model size. | `small` | | `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `en` | | `--language` | Source language code or `auto` | `auto` |
| `--task` | `transcribe` or `translate` | `transcribe` | | `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `simulstreaming` | | `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` | | `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
@@ -161,14 +160,9 @@ The package includes an HTML/JavaScript implementation [here](https://github.com
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| SimulStreaming backend options | Description | Default | | SimulStreaming backend options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` | | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` | | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` | | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -182,12 +176,25 @@ The package includes an HTML/JavaScript implementation [here](https://github.com
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` | | `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` | | `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| Diarization options | Description | Default | | Diarization options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` | | `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` | | `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` | | `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> For diarization using Diart, you need access to pyannote.audio models:
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
>4. Login with HuggingFace: `huggingface-cli login`
### 🚀 Deployment Guide ### 🚀 Deployment Guide
@@ -216,19 +223,39 @@ To deploy WhisperLiveKit in production:
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL 4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
### 🐋 Docker ## 🐋 Docker
A Dockerfile is provided which allows re-use of Python package installation options. Create a reusable image with only the basics and then run as a named container: Deploy the application easily using Docker with GPU or CPU support.
### Prerequisites
- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed
### Quick Start
**With GPU acceleration (recommended):**
```bash ```bash
docker build -t whisperlivekit-defaults . docker build -t wlk .
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults --model base docker run --gpus all -p 8000:8000 --name wlk wlk
docker start -i whisperlivekit
``` ```
> **Note**: For **large** models, ensure that your **docker runtime** has enough **memory** available **CPU only:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### Advanced Usage
**Custom configuration:**
```bash
# Example with custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### Memory Requirements
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
#### Customization #### Customization

258
ReadmeJP.md Normal file
View File

@@ -0,0 +1,258 @@
<h1 align="center">WhisperLiveKit</h1>
<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
#### 主要な研究による技術:
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
### アーキテクチャ
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
### インストールとクイックスタート
```bash
pip install whisperlivekit
```
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
>
> | OS | インストール方法 |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
#### クイックスタート
1. **文字起こしサーバーを起動します:**
```bash
whisperlivekit-server --model base --language en
```
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
#### オプションの依存関係
| オプション | `pip install` |
|-----------|-------------|
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Diartによる話者ダイアライゼーション | `diart` |
| オリジナルのWhisperバックエンド | `whisper` |
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
| Apple Silicon最適化バックエンド | `mlx-whisper` |
| OpenAI APIバックエンド | `openai` |
それらの使用方法については、以下の**パラメータと設定**を参照してください。
### 使用例
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
```bash
# デフォルト(small)より良いモデルを使用
whisperlivekit-server --model large-v3
# ダイアライゼーションと言語を指定した高度な設定
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global transcription_engine
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
yield
app = FastAPI(lifespan=lifespan)
async def handle_websocket_results(websocket: WebSocket, results_generator):
async for response in results_generator:
await websocket.send_json(response)
await websocket.send_json({"type": "ready_to_stop"})
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
global transcription_engine
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
results_generator = await audio_processor.create_tasks()
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
await websocket.accept()
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
```
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
## パラメータと設定
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
- `--backend` `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
- `--warmup-file`、もしあれば
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
- `--diarization`、使用したい場合。
残りは推奨しません。しかし、以下があなたのオプションです。
| パラメータ | 説明 | デフォルト |
|-----------|-------------|---------|
| `--model` | Whisperモデルのサイズ。 | `small` |
| `--language` | ソース言語コードまたは`auto` | `auto` |
| `--task` | `transcribe`または`translate` | `transcribe` |
| `--backend` | 処理バックエンド | `simulstreaming` |
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
| `--no-vad` | 音声区間検出を無効化 | `False` |
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
| `--host` | サーバーホストアドレス | `localhost` |
| `--port` | サーバーポート | `8000` |
| `--ssl-certfile` | SSL証明書ファイルへのパスHTTPSサポート用 | `None` |
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパスHTTPSサポート用 | `None` |
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment` | `segment` |
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAttフレームしきい値低いほど速く、高いほど正確 | `25` |
| `--beams` | ビームサーチのビーム数1 = 貪欲デコーディング) | `1` |
| `--decoder` | デコーダタイプを強制(`beam`または`greedy` | `auto` |
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
| `--init-prompt` | モデルの初期プロンプト | `None` |
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
| ダイアライゼーションオプション | 説明 | デフォルト |
|-----------|-------------|---------|
| `--diarization` | 話者識別を有効化 | `False` |
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
>4. HuggingFaceでログイン: `huggingface-cli login`
### 🚀 デプロイガイド
WhisperLiveKitを本番環境にデプロイするには
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
3. **Nginx設定** (本番環境で推奨):
```nginx
server {
listen 80;
server_name your-domain.com;
location / {
proxy_pass http://localhost:8000;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
}}
```
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
## 🐋 Docker
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
### 前提条件
- Dockerがシステムにインストールされていること
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
### クイックスタート
**GPUアクセラレーション付き (推奨):**
```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
**CPUのみ:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### 高度な使用法
**カスタム設定:**
```bash
# カスタムモデルと言語の例
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### メモリ要件
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
#### カスタマイズ
- `--build-arg` オプション:
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
## 🔮 ユースケース
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...

Binary file not shown.

Before

Width:  |  Height:  |  Size: 388 KiB

After

Width:  |  Height:  |  Size: 368 KiB

72
available_models.md Normal file
View File

@@ -0,0 +1,72 @@
# Available model sizes:
- tiny.en (english only)
- tiny
- base.en (english only)
- base
- small.en (english only)
- small
- medium.en (english only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo
## How to choose?
### Language Support
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
### Resource Constraints
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
- `base`: Good balance of speed and accuracy for basic use cases
- `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
- `large-v2`: Excellent accuracy, good multilingual support
- `large-v3`: Best overall accuracy and language support
### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Model Comparison Table
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)

BIN
demo.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 423 KiB

After

Width:  |  Height:  |  Size: 449 KiB

View File

@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "whisperlivekit" name = "whisperlivekit"
version = "0.2.6" version = "0.2.8"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [
{ name = "Quentin Fuxa" } { name = "Quentin Fuxa" }
@@ -18,6 +18,11 @@ classifiers = [
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech" "Topic :: Multimedia :: Sound/Audio :: Speech"
] ]
@@ -28,19 +33,15 @@ dependencies = [
"faster-whisper", "faster-whisper",
"uvicorn", "uvicorn",
"websockets", "websockets",
"torch", "torchaudio>=2.0.0",
"torch>=2.0.0",
"tqdm", "tqdm",
"tiktoken", "tiktoken",
'triton>=2.0.0,<3; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")' 'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
] ]
[project.optional-dependencies] [project.optional-dependencies]
diarization = ["diart"]
sentence = ["mosestokenizer", "wtpsplit"] sentence = ["mosestokenizer", "wtpsplit"]
whisper = ["whisper"]
whisper-timestamped = ["whisper-timestamped"]
mlx-whisper = ["mlx-whisper"]
openai = ["openai"]
[project.urls] [project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit" Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"

View File

@@ -1,12 +1,13 @@
from .audio_processor import AudioProcessor from .audio_processor import AudioProcessor
from .core import TranscriptionEngine from .core import TranscriptionEngine
from .parse_args import parse_args from .parse_args import parse_args
from .web.web_interface import get_web_interface_html from .web.web_interface import get_web_interface_html, get_inline_ui_html
__all__ = [ __all__ = [
"TranscriptionEngine", "TranscriptionEngine",
"AudioProcessor", "AudioProcessor",
"parse_args", "parse_args",
"get_web_interface_html", "get_web_interface_html",
"get_inline_ui_html",
"download_simulstreaming_backend", "download_simulstreaming_backend",
] ]

View File

@@ -4,13 +4,11 @@ from time import time, sleep
import math import math
import logging import logging
import traceback import traceback
from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken, Silence from whisperlivekit.timed_objects import ASRToken, Silence
from whisperlivekit.core import TranscriptionEngine, online_factory from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.trail_repetition import trim_tail_repetition
from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output, format_time
# Set up logging once # Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,10 +16,6 @@ logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker SENTINEL = object() # unique sentinel object for end of stream marker
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
class AudioProcessor: class AudioProcessor:
""" """
Processes audio streams for transcription and diarization. Processes audio streams for transcription and diarization.
@@ -66,7 +60,6 @@ class AudioProcessor:
# Models and processing # Models and processing
self.asr = models.asr self.asr = models.asr
self.tokenizer = models.tokenizer self.tokenizer = models.tokenizer
self.diarization = models.diarization
self.vac_model = models.vac_model self.vac_model = models.vac_model
if self.args.vac: if self.args.vac:
self.vac = FixedVADIterator(models.vac_model) self.vac = FixedVADIterator(models.vac_model)
@@ -99,6 +92,11 @@ class AudioProcessor:
# Initialize transcription engine if enabled # Initialize transcription engine if enabled
if self.args.transcription: if self.args.transcription:
self.online = online_factory(self.args, models.asr, models.tokenizer) self.online = online_factory(self.args, models.asr, models.tokenizer)
# Initialize diarization engine if enabled
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
def convert_pcm_to_float(self, pcm_buffer): def convert_pcm_to_float(self, pcm_buffer):
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
@@ -108,17 +106,6 @@ class AudioProcessor:
"""Thread-safe update of transcription with new data.""" """Thread-safe update of transcription with new data."""
async with self.lock: async with self.lock:
self.tokens.extend(new_tokens) self.tokens.extend(new_tokens)
# self.tokens, has_been_trimmed = trim_tail_repetition(
# self.tokens,
# key=lambda t: t.text.strip().lower(),
# min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too
# max_tail=200,
# prefer="longest", # prefer removing the longest repeated phrase
# keep=1
# )
# if has_been_trimmed:
# print('HAS BEEN TRIMMED !')
self.buffer_transcription = buffer self.buffer_transcription = buffer
self.end_buffer = end_buffer self.end_buffer = end_buffer
self.sep = sep self.sep = sep
@@ -133,7 +120,7 @@ class AudioProcessor:
async def add_dummy_token(self): async def add_dummy_token(self):
"""Placeholder token when no transcription is available.""" """Placeholder token when no transcription is available."""
async with self.lock: async with self.lock:
current_time = time() - self.beg_loop current_time = time() - self.beg_loop if self.beg_loop else 0
self.tokens.append(ASRToken( self.tokens.append(ASRToken(
start=current_time, end=current_time + 1, start=current_time, end=current_time + 1,
text=".", speaker=-1, is_dummy=True text=".", speaker=-1, is_dummy=True
@@ -303,12 +290,12 @@ class AudioProcessor:
if type(item) is Silence: if type(item) is Silence:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s" asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
if self.tokens: if self.tokens:
asr_processing_logs += " | last_end = {self.tokens[-1].end} |" asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
logger.info(asr_processing_logs) logger.info(asr_processing_logs)
if type(item) is Silence: if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration, self.tokens[-1].end) self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
continue continue
if isinstance(item, np.ndarray): if isinstance(item, np.ndarray):
@@ -433,7 +420,7 @@ class AudioProcessor:
buffer_diarization = state["buffer_diarization"] buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"] end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"] sep = state["sep"]
# Add dummy tokens if needed # Add dummy tokens if needed
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization: if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
await self.add_dummy_token() await self.add_dummy_token()
@@ -442,45 +429,13 @@ class AudioProcessor:
tokens = state["tokens"] tokens = state["tokens"]
# Format output # Format output
previous_speaker = -1 lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
lines = [] state,
last_end_diarized = 0 self.silence,
undiarized_text = [] current_time = time() - self.beg_loop if self.beg_loop else None,
current_time = time() - self.beg_loop if self.beg_loop else None diarization = self.args.diarization,
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, self.silence) debug = self.debug
for token in tokens: )
speaker = token.speaker
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
# Handle diarization
if self.args.diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
debug_info = ""
if self.debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if speaker != previous_speaker or not lines:
lines.append({
"speaker": speaker,
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
})
previous_speaker = speaker
elif token.text: # Only append if text isn't empty
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
# Handle undiarized text # Handle undiarized text
if undiarized_text: if undiarized_text:
combined = sep.join(undiarized_text) combined = sep.join(undiarized_text)
@@ -510,7 +465,7 @@ class AudioProcessor:
"buffer_transcription": buffer_transcription, "buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization, "buffer_diarization": buffer_diarization,
"remaining_time_transcription": state["remaining_time_transcription"], "remaining_time_transcription": state["remaining_time_transcription"],
"remaining_time_diarization": state["remaining_time_diarization"] "remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
} }
current_response_signature = f"{response_status} | " + \ current_response_signature = f"{response_status} | " + \

View File

@@ -2,7 +2,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio import asyncio
import logging import logging
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
@@ -19,6 +19,15 @@ transcription_engine = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
#to remove after 0.2.8
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
logger.warning(f"""
{'='*50}
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
{'='*50}
""")
global transcription_engine global transcription_engine
transcription_engine = TranscriptionEngine( transcription_engine = TranscriptionEngine(
**vars(args), **vars(args),
@@ -38,7 +47,7 @@ app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/") @app.get("/")
async def get(): async def get():
return HTMLResponse(get_web_interface_html()) return HTMLResponse(get_inline_ui_html())
async def handle_websocket_results(websocket, results_generator): async def handle_websocket_results(websocket, results_generator):
@@ -52,7 +61,7 @@ async def handle_websocket_results(websocket, results_generator):
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info("WebSocket disconnected while handling results (client likely closed connection).") logger.info("WebSocket disconnected while handling results (client likely closed connection).")
except Exception as e: except Exception as e:
logger.error(f"Error in WebSocket results handler: {e}") logger.exception(f"Error in WebSocket results handler: {e}")
@app.websocket("/asr") @app.websocket("/asr")

View File

@@ -46,6 +46,7 @@ class TranscriptionEngine:
"confidence_validation": False, "confidence_validation": False,
"buffer_trimming_sec": 15, "buffer_trimming_sec": 15,
# simulstreaming params: # simulstreaming params:
"disable_fast_encoder": False,
"frame_threshold": 25, "frame_threshold": 25,
"beams": 1, "beams": 1,
"decoder_type": None, "decoder_type": None,
@@ -57,10 +58,10 @@ class TranscriptionEngine:
"static_init_prompt": None, "static_init_prompt": None,
"max_context_tokens": None, "max_context_tokens": None,
"model_path": './base.pt', "model_path": './base.pt',
"diarization_backend": "diart", "diarization_backend": "sortformer",
# diart params: # diart params:
"segmentation_model": "pyannote/segmentation-3.0", "segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding", "embedding_model": "pyannote/embedding",
} }
config_dict = {**defaults, **kwargs} config_dict = {**defaults, **kwargs}
@@ -97,7 +98,7 @@ class TranscriptionEngine:
simulstreaming_kwargs = {} simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len', for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt', 'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']: 'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
if hasattr(self.args, attr): if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr) simulstreaming_kwargs[attr] = getattr(self.args, attr)
@@ -121,13 +122,14 @@ class TranscriptionEngine:
if self.args.diarization: if self.args.diarization:
if self.args.diarization_backend == "diart": if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization from whisperlivekit.diarization.diart_backend import DiartDiarization
self.diarization = DiartDiarization( self.diarization_model = DiartDiarization(
block_duration=self.args.min_chunk_size, block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model, segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model embedding_model_name=self.args.embedding_model
) )
elif self.args.diarization_backend == "sortformer": elif self.args.diarization_backend == "sortformer":
raise ValueError('Sortformer backend in developement') from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization_model = SortformerDiarization()
else: else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}") raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
@@ -152,4 +154,16 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
confidence_validation = args.confidence_validation confidence_validation = args.confidence_validation
) )
return online return online
def online_diarization_factory(args, diarization_backend):
if args.diarization_backend == "diart":
online = diarization_backend
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommanded
if args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend)
return online

View File

@@ -1,145 +1,457 @@
import numpy as np import numpy as np
import torch import torch
import logging import logging
import threading
import time
import wave
from typing import List, Optional
from queue import SimpleQueue, Empty
from whisperlivekit.timed_objects import SpeakerSegment from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from nemo.collections.asr.models import SortformerEncLabelModel from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
except ImportError: except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""") raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
def __init__(self):
self.spkcache = None # Speaker cache to store embeddings from start
self.spkcache_lengths = None
self.spkcache_preds = None # speaker cache predictions
self.fifo = None # to save the embedding from the latest chunks
self.fifo_lengths = None
self.fifo_preds = None
self.spk_perm = None
self.mean_sil_emb = None
self.n_sil_frames = None
class SortformerDiarization: class SortformerDiarization:
def __init__(self, model_name="nvidia/diar_streaming_sortformer_4spk-v2"): def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name) """
self.diar_model.eval() Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
"""
self._load_model(model_name)
def _load_model(self, model_name: str):
"""Load and configure the Sortformer model for streaming."""
try:
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
self.diar_model.eval()
if torch.cuda.is_available(): if torch.cuda.is_available():
self.diar_model.to(torch.device("cuda")) self.diar_model.to(torch.device("cuda"))
logger.info("Using CUDA for Sortformer model")
else:
logger.info("Using CPU for Sortformer model")
# Streaming parameters for speed self.diar_model.sortformer_modules.chunk_len = 10
self.diar_model.sortformer_modules.chunk_len = 12 self.diar_model.sortformer_modules.subsampling_factor = 10
self.diar_model.sortformer_modules.chunk_right_context = 1 self.diar_model.sortformer_modules.chunk_right_context = 0
self.diar_model.sortformer_modules.spkcache_len = 188 self.diar_model.sortformer_modules.chunk_left_context = 10
self.diar_model.sortformer_modules.fifo_len = 188 self.diar_model.sortformer_modules.spkcache_len = 188
self.diar_model.sortformer_modules.spkcache_update_period = 144 self.diar_model.sortformer_modules.fifo_len = 188
self.diar_model.sortformer_modules.log = False self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules._check_streaming_parameters() self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters()
self.batch_size = 1
self.processed_signal_offset = torch.zeros((self.batch_size,), dtype=torch.long, device=self.diar_model.device) except Exception as e:
logger.error(f"Failed to load Sortformer model: {e}")
raise
class SortformerDiarizationOnline:
def __init__(self, shared_model, sample_rate: int = 16000):
"""
Initialize the streaming Sortformer diarization system.
self.audio_buffer = np.array([], dtype=np.float32) Args:
self.sample_rate = 16000 sample_rate: Audio sample rate (default: 16000)
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
"""
self.sample_rate = sample_rate
self.speaker_segments = [] self.speaker_segments = []
self.buffer_audio = np.array([], dtype=np.float32)
self.streaming_state = self.diar_model.sortformer_modules.init_streaming_state( self.segment_lock = threading.Lock()
batch_size=self.batch_size, self.global_time_offset = 0.0
async_streaming=True, self.processed_time = 0.0
device=self.diar_model.device self.debug = False
self.diar_model = shared_model.diar_model
self.audio2mel = AudioToMelSpectrogramPreprocessor(
window_size=0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0
) )
self.total_preds = torch.zeros((self.batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=self.diar_model.device)
self.chunk_duration_seconds = (
self.diar_model.sortformer_modules.chunk_len *
def _prepare_audio_signal(self, signal): self.diar_model.sortformer_modules.subsampling_factor *
audio_signal = torch.tensor(signal).unsqueeze(0).to(self.diar_model.device) self.diar_model.preprocessor._cfg.window_stride
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(self.diar_model.device)
processed_signal, processed_signal_length = self.diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
return processed_signal, processed_signal_length
def _create_streaming_loader(self, processed_signal, processed_signal_length):
streaming_loader = self.diar_model.sortformer_modules.streaming_feat_loader(
feat_seq=processed_signal,
feat_seq_length=processed_signal_length,
feat_seq_offset=self.processed_signal_offset,
) )
return streaming_loader
self._init_streaming_state()
self._previous_chunk_features = None
self._chunk_index = 0
self._len_prediction = None
# Audio buffer to store PCM chunks for debugging
self.audio_buffer = []
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
self.audio_chunk_buffer = []
self.accumulated_duration = 0.0
logger.info("SortformerDiarization initialized successfully")
def _init_streaming_state(self):
"""Initialize the streaming state for the model."""
batch_size = 1
device = self.diar_model.device
self.streaming_state = StreamingSortformerState()
self.streaming_state.spkcache = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.spkcache_preds = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
device=device
)
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.fifo = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
# Initialize total predictions tensor
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: float):
"""
Insert silence period by adjusting the global time offset.
Args:
silence_duration: Duration of silence in seconds
"""
with self.segment_lock:
self.global_time_offset += silence_duration
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
async def diarize(self, pcm_array: np.ndarray): async def diarize(self, pcm_array: np.ndarray):
""" """
Process an incoming audio chunk for diarization. Process audio data for diarization in streaming fashion.
Args:
pcm_array: Audio data as numpy array
""" """
self.audio_buffer = np.concatenate([self.audio_buffer, pcm_array]) try:
if self.debug:
# Process in fixed-size chunks (e.g., 1 second) self.audio_buffer.append(pcm_array.copy())
chunk_size = self.sample_rate # 1 second of audio
while len(self.audio_buffer) >= chunk_size:
chunk_to_process = self.audio_buffer[:chunk_size]
self.audio_buffer = self.audio_buffer[chunk_size:]
processed_signal, processed_signal_length = self._prepare_audio_signal(chunk_to_process) threshold = int(self.chunk_duration_seconds * self.sample_rate)
current_offset_seconds = self.processed_signal_offset.item() * self.diar_model.preprocessor._cfg.window_stride self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
if not len(self.buffer_audio) >= threshold:
streaming_loader = self._create_streaming_loader(processed_signal, processed_signal_length) return
frame_duration_s = self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride audio = self.buffer_audio[:threshold]
chunk_duration_seconds = self.diar_model.sortformer_modules.chunk_len * frame_duration_s self.buffer_audio = self.buffer_audio[threshold:]
audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:]
total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total_features = processed_signal_chunk
self._previous_chunk_features = processed_signal_chunk
chunk_feat_seq_t = torch.transpose(total_features, 1, 2)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
# Convert predictions to speaker segments
self._process_predictions()
self._chunk_index += 1
except Exception as e:
logger.error(f"Error in diarize: {e}")
raise
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader: def _process_predictions(self):
with torch.inference_mode(): """Process model predictions and convert to speaker segments."""
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step( try:
processed_signal=chunk_feat_seq_t, preds_np = self.total_preds[0].cpu().numpy()
processed_signal_length=feat_lengths, active_speakers = np.argmax(preds_np, axis=1)
streaming_state=self.streaming_state,
total_preds=self.total_preds, if self._len_prediction is None:
left_offset=left_offset, self._len_prediction = len(active_speakers)
right_offset=right_offset,
) # Get predictions for current chunk
frame_duration = self.chunk_duration_seconds / self._len_prediction
current_chunk_preds = active_speakers[-self._len_prediction:]
with self.segment_lock:
# Process predictions into segments
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
for idx, spk in enumerate(current_chunk_preds):
start_time = base_time + idx * frame_duration
end_time = base_time + (idx + 1) * frame_duration
num_new_frames = feat_lengths[0].item() # Check if this continues the last segment or starts a new one
if (self.speaker_segments and
# Get predictions for the current chunk from the end of total_preds self.speaker_segments[-1].speaker == spk and
preds_np = self.total_preds[0, -num_new_frames:].cpu().numpy() abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
active_speakers = np.argmax(preds_np, axis=1) # Continue existing segment
self.speaker_segments[-1].end = end_time
for idx, spk in enumerate(active_speakers): else:
start_time = current_offset_seconds + (i * chunk_duration_seconds) + (idx * frame_duration_s)
end_time = start_time + frame_duration_s
if self.speaker_segments and self.speaker_segments[-1].speaker == spk + 1: # Create new segment
self.speaker_segments[-1].end = end_time self.speaker_segments.append(SpeakerSegment(
else: speaker=spk,
self.speaker_segments.append(SpeakerSegment( start=start_time,
speaker=int(spk + 1), end=end_time
start=start_time, ))
end=end_time
)) # Update processed time
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
self.processed_signal_offset += processed_signal_length
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
except Exception as e:
logger.error(f"Error processing predictions: {e}")
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
def assign_speakers_to_tokens(self, tokens: list, **kwargs) -> list:
""" """
Assign speakers to tokens based on timing overlap with speaker segments. Assign speakers to tokens based on timing overlap with speaker segments.
Args:
tokens: List of tokens with timing information
use_punctuation_split: Whether to use punctuation for boundary refinement
Returns:
List of tokens with speaker assignments
""" """
for token in tokens: with self.segment_lock:
for segment in self.speaker_segments: segments = self.speaker_segments.copy()
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = segment.speaker if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment")
return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
use_punctuation_split = False
if not use_punctuation_split:
# Simple overlap-based assignment
for token in tokens:
token.speaker = -1 # Default to no speaker
for segment in segments:
# Check for timing overlap
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
break
else:
# Use punctuation-aware assignment (similar to diart_backend)
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens return tokens
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
"""
Assign speakers to tokens with punctuation-aware boundary adjustment.
Args:
segments: List of speaker segments
tokens: List of tokens to assign speakers to
Returns:
List of tokens with speaker assignments
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
# Convert segments to concatenated format
segments_concatenated = self._concatenate_speakers(segments)
# Adjust segment boundaries based on punctuation
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
"""
Concatenate consecutive segments from the same speaker.
Args:
segments: List of speaker segments
Returns:
List of concatenated speaker segments
"""
if not segments:
return []
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = segment.speaker + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
return self.speaker_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.speaker_segments = [
segment for segment in self.speaker_segments
if current_time - segment.end < older_than
]
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
def close(self): def close(self):
""" """Close the diarization system and clean up resources."""
Cleanup resources. logger.info("Closing SortformerDiarization")
""" with self.segment_lock:
logger.info("Closing SortformerDiarization.") self.speaker_segments.clear()
if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio_data_int16.tobytes())
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
def extract_number(s: str) -> int:
"""Extract number from speaker string (compatibility function)."""
import re
m = re.search(r'\d+', s)
return int(m.group()) if m else 0
if __name__ == '__main__': if __name__ == '__main__':
import asyncio
import librosa import librosa
an4_audio = 'new_audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000) async def main():
"""TEST ONLY."""
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
diarization_pipeline = SortformerDiarization() print("\n" + "=" * 50)
print("ground truth:")
# Simulate streaming print("Speaker 0: 0:00 - 0:09")
chunk_size = 16000 # 1 second print("Speaker 1: 0:09 - 0:19")
for i in range(0, len(signal), chunk_size): print("Speaker 2: 0:19 - 0:25")
chunk = signal[i:i+chunk_size] print("Speaker 0: 0:25 - 0:30")
import asyncio print("=" * 50)
asyncio.run(diarization_pipeline.diarize(chunk))
diarization = SortformerDiarization(sample_rate=16000)
for segment in diarization_pipeline.speaker_segments: chunk_size = 1600
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}")
segments = diarization.get_segments()
print("\nDiarization results:")
for segment in segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
asyncio.run(main())

View File

@@ -1,257 +0,0 @@
import numpy as np
import torch
import logging
import math
logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
# diar_model.sortformer_modules.chunk_len = 6
# diar_model.sortformer_modules.spkcache_len = 188
# diar_model.sortformer_modules.chunk_right_context = 7
# diar_model.sortformer_modules.fifo_len = 188
# diar_model.sortformer_modules.spkcache_update_period = 144
# diar_model.sortformer_modules.log = False
# here we change the settings for our goal: speed!
# we want batches of around 1 second. one frame is 0.08s, so 1s is 12.5 frames. we take 12.
diar_model.sortformer_modules.chunk_len = 12
# for more speed, we reduce the 'right context'. it's like looking less into the future.
diar_model.sortformer_modules.chunk_right_context = 1
# we keep the rest same for now
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
batch_size = 1
processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)
# from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
# from nemo.collections.asr.modules.audio_preprocessing import get_features
from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessor
def prepare_audio_signal(signal):
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128).get_features(audio_signal, audio_signal_length)
return processed_signal, processed_signal_length
def streaming_feat_loader(
feat_seq, feat_seq_length, feat_seq_offset
):
"""
Load a chunk of feature sequence for streaming inference.
Args:
feat_seq (torch.Tensor): Tensor containing feature sequence
Shape: (batch_size, feat_dim, feat frame count)
feat_seq_length (torch.Tensor): Tensor containing feature sequence lengths
Shape: (batch_size,)
feat_seq_offset (torch.Tensor): Tensor containing feature sequence offsets
Shape: (batch_size,)
Returns:
chunk_idx (int): Index of the current chunk
chunk_feat_seq (torch.Tensor): Tensor containing the chunk of feature sequence
Shape: (batch_size, diar frame count, feat_dim)
feat_lengths (torch.Tensor): Tensor containing lengths of the chunk of feature sequence
Shape: (batch_size,)
"""
feat_len = feat_seq.shape[2]
num_chunks = math.ceil(feat_len / (diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor))
if False:
logging.info(
f"feat_len={feat_len}, num_chunks={num_chunks}, "
f"feat_seq_length={feat_seq_length}, feat_seq_offset={feat_seq_offset}"
)
stt_feat, end_feat, chunk_idx = 0, 0, 0
while end_feat < feat_len:
left_offset = min(diar_model.sortformer_modules.chunk_left_context * diar_model.sortformer_modules.subsampling_factor, stt_feat)
end_feat = min(stt_feat + diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor, feat_len)
right_offset = min(diar_model.sortformer_modules.chunk_right_context * diar_model.sortformer_modules.subsampling_factor, feat_len - end_feat)
chunk_feat_seq = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
0, chunk_feat_seq.shape[2]
)
feat_lengths = feat_lengths * (feat_seq_offset < end_feat)
stt_feat = end_feat
chunk_feat_seq_t = torch.transpose(chunk_feat_seq, 1, 2)
if False:
logging.info(
f"chunk_idx: {chunk_idx}, "
f"chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
f"chunk_feat_lengths: {feat_lengths}"
)
yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
chunk_idx += 1
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(signal, chunks):
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128).get_features(audio_signal, audio_signal_length)
streaming_loader = streaming_feat_loader(processed_signal, processed_signal_length, processed_signal_offset)
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
print(f"Chunk duration: {chunk_duration_seconds} seconds")
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t, _, _, _ in streaming_loader:
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
print(chunk_feat_seq_t.shape, total_preds.shape)
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': i * chunk_duration_seconds + idx * frame_duration,
'end_time': i * chunk_duration_seconds + (idx + 1) * frame_duration,
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
print(l_speakers)
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
if __name__ == '__main__':
import librosa
an4_audio = 'new_audio_test.mp3'
signal, sr = librosa.load(an4_audio,sr=16000)
"""
ground truth:
speaker 0 : 0:00 - 0:09
speaker 1 : 0:09 - 0:19
speaker 2 : 0:19 - 0:25
speaker 0 : 0:25 - end
"""
# Simulate streaming
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(signal, chunks)

View File

@@ -0,0 +1,205 @@
import numpy as np
import torch
import logging
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
import librosa
logger = logging.getLogger(__name__)
def load_model():
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
#we target 1 second lag for the moment. chunk_len could be reduced.
diar_model.sortformer_modules.chunk_len = 10
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
diar_model.sortformer_modules.chunk_right_context = 0 #no.
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
audio2mel = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
return diar_model, audio2mel
diar_model, audio2mel = load_model()
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(chunks):
"""
what it does:
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
2. STFT: Computes the Short-Time Fourier Transform using:
- the window of window_size=0.025 --> size of a window : 400 samples
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
6. Normalization: Skips normalization since `normalize="NA"`
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
"""
previous_chunk = None
l_chunk_feat_seq_t = []
for chunk in chunks:
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
if previous_chunk is not None:
to_add = previous_chunk[:, :, -99:]
total = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total = processed_signal_chunk
previous_chunk = processed_signal_chunk
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
batch_size = 1
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
for speaker in l_speakers:
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
if __name__ == '__main__':
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
# signal = signal[:-(len(signal)%16000)]
print("\n" + "=" * 50)
print("Expected ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(chunks)

View File

@@ -61,7 +61,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--diarization-backend", "--diarization-backend",
type=str, type=str,
default="diart", default="sortformer",
choices=["sortformer", "diart"], choices=["sortformer", "diart"],
help="The diarization backend to use.", help="The diarization backend to use.",
) )
@@ -161,6 +161,14 @@ def parse_args():
# SimulStreaming-specific arguments # SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)') simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--disable-fast-encoder",
action="store_true",
default=False,
dest="disable_fast_encoder",
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
)
simulstreaming_group.add_argument( simulstreaming_group.add_argument(
"--frame-threshold", "--frame-threshold",

View File

@@ -81,7 +81,7 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
if not tokens: if not tokens:
return [], buffer_transcription, buffer_diarization return [], buffer_transcription, buffer_diarization
last_token = tokens[-1] last_token = tokens[-1]
if tokens and ( if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION current_time - last_token.end >= END_SILENCE_DURATION
or or
(current_time - last_token.end >= 3 and vac_detected_silence) (current_time - last_token.end >= 3 and vac_detected_silence)

View File

@@ -0,0 +1,138 @@
import logging
from datetime import timedelta
from whisperlivekit.remove_silences import handle_silences
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?'}
CHECK_AROUND = 4
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
def is_punctuation(token):
if token.text.strip() in PUNCTUATION_MARKS:
return True
return False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if is_punctuation(tokens[ind]):
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if is_punctuation(token):
break
if token.speaker != speaker:
return ind, token.speaker
return None, speaker
def new_line(
token,
speaker,
last_end_diarized,
debug_info = ""
):
return {
"speaker": int(speaker),
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
}
def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
if token.text:
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
def format_output(state, silence, current_time, diarization, debug):
tokens = state["tokens"]
buffer_transcription = state["buffer_transcription"]
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"]
previous_speaker = -1
lines = []
last_end_diarized = 0
undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
if diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
debug_info = ""
if debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if not lines:
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
continue
else:
previous_speaker = lines[-1]['speaker']
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if speaker != previous_speaker:
# perfect, diarization perfectly aligned
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
last_punctuation, next_punctuation = None, None
continue
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
lines.append(new_line(token, new_speaker, last_end_diarized, debug_info = ""))
else:
# No speaker change to come
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
if speaker != previous_speaker:
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
continue
elif next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
# should become:
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
# lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
pass
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
return lines, undiarized_text, buffer_transcription, ''

View File

@@ -13,15 +13,25 @@ import os
import gc import gc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
try: try:
import torch from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
from whisperlivekit.simul_whisper.config import AlignAttConfig HAS_MLX_WHISPER = True
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper except ImportError:
from whisperlivekit.simul_whisper.whisper import tokenizer HAS_MLX_WHISPER = False
except ImportError as e: if HAS_MLX_WHISPER:
raise ImportError( HAS_FASTER_WHISPER = False
"""SimulStreaming dependencies are not available. else:
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""") try:
from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# TOO_MANY_REPETITIONS = 3 # TOO_MANY_REPETITIONS = 3
@@ -42,6 +52,8 @@ class SimulStreamingOnlineProcessor:
self.committed: List[ASRToken] = [] self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = []
self.load_new_backend() self.load_new_backend()
#can be moved
if asr.tokenizer: if asr.tokenizer:
self.model.tokenizer = asr.tokenizer self.model.tokenizer = asr.tokenizer
@@ -49,7 +61,10 @@ class SimulStreamingOnlineProcessor:
model = self.asr.get_new_model_instance() model = self.asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper( self.model = PaddedAlignAttWhisper(
cfg=self.asr.cfg, cfg=self.asr.cfg,
loaded_model=model) loaded_model=model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def insert_silence(self, silence_duration, offset): def insert_silence(self, silence_duration, offset):
""" """
@@ -212,7 +227,7 @@ class SimulStreamingASR():
logger.warning(SIMULSTREAMING_LICENSE) logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile self.logfile = logfile
self.transcribe_kargs = {} self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan self.original_language = lan
self.model_path = kwargs.get('model_path', './large-v3.pt') self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25) self.frame_threshold = kwargs.get('frame_threshold', 25)
@@ -229,7 +244,8 @@ class SimulStreamingASR():
self.max_context_tokens = kwargs.get('max_context_tokens', None) self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None) self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1) self.preload_model_count = kwargs.get('preload_model_count', 1)
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
self.fast_encoder = False
if model_dir is not None: if model_dir is not None:
self.model_path = model_dir self.model_path = model_dir
elif modelsize is not None: elif modelsize is not None:
@@ -249,11 +265,6 @@ class SimulStreamingASR():
} }
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt') self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.cfg = AlignAttConfig( self.cfg = AlignAttConfig(
model_path=self.model_path, model_path=self.model_path,
segment_length=self.segment_length, segment_length=self.segment_length,
@@ -271,17 +282,52 @@ class SimulStreamingASR():
static_init_prompt=self.static_init_prompt, static_init_prompt=self.static_init_prompt,
) )
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "") self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.models = [self.load_model() for i in range(self.preload_model_count)]
self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
self.fast_encoder = True
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_encoder = WhisperModel(
self.model_name,
device='auto',
compute_type='auto',
)
self.fast_encoder = True
self.models = [self.load_model() for i in range(self.preload_model_count)]
def load_model(self): def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path) whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
warmup_audio = load_file(self.warmup_file) warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language) if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = PaddedAlignAttWhisper(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model return whisper_model
def get_new_model_instance(self): def get_new_model_instance(self):
@@ -301,10 +347,12 @@ class SimulStreamingASR():
def set_translate_task(self): def set_translate_task(self):
"""Set up translation task.""" """Set up translation task."""
if self.cfg.language == 'auto':
raise Exception('Translation cannot be done with language = auto')
return tokenizer.get_tokenizer( return tokenizer.get_tokenizer(
multilingual=True, multilingual=True,
language=self.model.cfg.language, language=self.cfg.language,
num_languages=self.model.model.num_languages, num_languages=99,
task="translate" task="translate"
) )

View File

@@ -0,0 +1,72 @@
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper
mlx_model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
def load_mlx_encoder(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
# we only want to load the encoder weights here.
# Size examples: for tiny.en,
# Decoder weights: 59110771 bytes
# Encoder weights: 15268874 bytes
encoder_weights = {}
encoder_weights['encoder'] = weights['encoder']
del(weights)
model.update(encoder_weights)
mx.eval(model.parameters())
return model

View File

@@ -14,7 +14,7 @@ from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens,
from .beam import BeamPyTorchInference from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif from .eow_detection import fire_at_boundary, load_cif
import os import os
from time import time
from .token_buffer import TokenBuffer from .token_buffer import TokenBuffer
import numpy as np import numpy as np
@@ -23,8 +23,22 @@ from .generation_progress import *
DEC_PAD = 50257 DEC_PAD = 50257
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import sys
import wave try:
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
HAS_MLX_WHISPER = True
except ImportError:
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper: # New features added to the original version of Simul-Whisper:
# - large-v3 model support # - large-v3 model support
@@ -33,7 +47,13 @@ import wave
# - prompt -- static vs. non-static # - prompt -- static vs. non-static
# - context # - context
class PaddedAlignAttWhisper: class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig, loaded_model=None) -> None: def __init__(
self,
cfg: AlignAttConfig,
loaded_model=None,
mlx_encoder=None,
fw_encoder=None,
) -> None:
self.log_segments = 0 self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "") model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path)) model_path = os.path.dirname(os.path.abspath(cfg.model_path))
@@ -42,6 +62,11 @@ class PaddedAlignAttWhisper:
else: else:
self.model = load_model(name=model_name, download_root=model_path) self.model = load_model(name=model_name, download_root=model_path)
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions( self.decode_options = DecodingOptions(
@@ -151,6 +176,15 @@ class PaddedAlignAttWhisper:
for hook in self.l_hooks: for hook in self.l_hooks:
hook.remove() hook.remove()
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None): def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer( self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual, multilingual=self.tokenizer_is_multilingual,
@@ -359,20 +393,36 @@ class PaddedAlignAttWhisper:
else: else:
input_segments = self.segments[0] input_segments = self.segments[0]
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time()
# mel + padding to 30s if self.mlx_encoder:
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
device=self.model.device).unsqueeze(0) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
# trim to 3000 mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
mel = pad_or_trim(mel_padded, N_FRAMES) encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
# the len of actual audio device = 'cpu'
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2) elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
# encode content_mel_len = int(audio_length_seconds * 100)//2
encoder_feature = self.model.encoder(mel) mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate))
device = 'cpu'
else:
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
device = mel.device
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
# logger.debug(f"Encoder feature shape: {encoder_feature.shape}") # logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
# if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): # if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# logger.debug("mel ") # logger.debug("mel ")
@@ -397,7 +447,7 @@ class PaddedAlignAttWhisper:
####################### Decoding loop ####################### Decoding loop
logger.info("Decoding loop starts\n") logger.info("Decoding loop starts\n")
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device) sum_logprobs = torch.zeros(self.cfg.beam_size, device=device)
completed = False completed = False
attn_of_alignment_heads = None attn_of_alignment_heads = None

View File

@@ -105,6 +105,7 @@ def load_model(
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
download_root: str = None, download_root: str = None,
in_memory: bool = False, in_memory: bool = False,
decoder_only=False
) -> Whisper: ) -> Whisper:
""" """
Load a Whisper ASR model Load a Whisper ASR model
@@ -151,7 +152,14 @@ def load_model(
del checkpoint_file del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims) model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None: if alignment_heads is not None:

View File

@@ -253,16 +253,18 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module): class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions): def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
super().__init__() super().__init__()
self.dims = dims self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels, if not decoder_only:
self.dims.n_audio_ctx, self.encoder = AudioEncoder(
self.dims.n_audio_state, self.dims.n_mels,
self.dims.n_audio_head, self.dims.n_audio_ctx,
self.dims.n_audio_layer, self.dims.n_audio_state,
) self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder( self.decoder = TextDecoder(
self.dims.n_vocab, self.dims.n_vocab,
self.dims.n_text_ctx, self.dims.n_text_ctx,

View File

@@ -31,21 +31,21 @@ def load_file(warmup_file=None, timeout=5):
logger.debug(f"Download successful in {time.time() - start_time:.2f}s") logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e: except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.") logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False return None
finally: finally:
socket.setdefaulttimeout(original_timeout) socket.setdefaulttimeout(original_timeout)
elif not warmup_file: elif not warmup_file:
return False return None
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0: if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.") logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False return None
try: try:
audio, sr = librosa.load(warmup_file, sr=16000) audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e: except Exception as e:
logger.warning(f"Failed to load audio file: {e}") logger.warning(f"Failed to load audio file: {e}")
return False return None
return audio return audio
def warmup_asr(asr, warmup_file=None, timeout=5): def warmup_asr(asr, warmup_file=None, timeout=5):

View File

@@ -184,7 +184,7 @@ body {
.settings { .settings {
display: flex; display: flex;
flex-direction: column; flex-wrap: wrap;
align-items: flex-start; align-items: flex-start;
gap: 12px; gap: 12px;
} }
@@ -198,23 +198,27 @@ body {
#chunkSelector, #chunkSelector,
#websocketInput, #websocketInput,
#themeSelector { #themeSelector,
#microphoneSelect {
font-size: 16px; font-size: 16px;
padding: 5px 8px; padding: 5px 8px;
border-radius: 8px; border-radius: 8px;
border: 1px solid var(--border); border: 1px solid var(--border);
background-color: var(--button-bg); background-color: var(--button-bg);
color: var(--text); color: var(--text);
max-height: 34px; max-height: 30px;
} }
#websocketInput { #microphoneSelect {
width: 220px; width: 100%;
max-width: 190px;
min-width: 120px;
} }
#chunkSelector:focus, #chunkSelector:focus,
#websocketInput:focus, #websocketInput:focus,
#themeSelector:focus { #themeSelector:focus,
#microphoneSelect:focus {
outline: none; outline: none;
border-color: #007bff; border-color: #007bff;
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15); box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
@@ -247,9 +251,9 @@ label {
} }
.theme-selector-container { .theme-selector-container {
position: absolute; display: flex;
top: 20px; align-items: center;
right: 20px; margin-top: 17px;
} }
.segmented label { .segmented label {
@@ -400,3 +404,57 @@ label {
font-size: 14px; font-size: 14px;
margin-bottom: 0px; margin-bottom: 0px;
} }
/* for smaller screens */
@media (max-width: 768px) {
.settings-container {
flex-direction: column;
gap: 10px;
}
.settings {
justify-content: center;
gap: 8px;
}
.field {
align-items: center;
}
#websocketInput,
#microphoneSelect {
min-width: 100px;
max-width: 160px;
}
.theme-selector-container {
margin-top: 10px;
}
}
@media (max-width: 480px) {
body {
margin: 10px;
}
.settings {
flex-direction: column;
align-items: center;
gap: 6px;
}
#websocketInput,
#microphoneSelect {
max-width: 140px;
}
.segmented label {
padding: 4px 8px;
font-size: 12px;
}
.segmented img {
width: 14px;
height: 14px;
}
}

View File

@@ -1,61 +1,73 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WhisperLiveKit</title> <title>WhisperLiveKit</title>
<link rel="stylesheet" href="/web/live_transcription.css" /> <link rel="stylesheet" href="/web/live_transcription.css" />
</head> </head>
<body> <body>
<div class="settings-container"> <div class="settings-container">
<button id="recordButton"> <button id="recordButton">
<div class="shape-container"> <div class="shape-container">
<div class="shape"></div> <div class="shape"></div>
</div> </div>
<div class="recording-info"> <div class="recording-info">
<div class="wave-container"> <div class="wave-container">
<canvas id="waveCanvas"></canvas> <canvas id="waveCanvas"></canvas>
</div>
<div class="timer">00:00</div>
</div>
</button>
<div class="settings">
<div class="field">
<label for="websocketInput">Websocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
<div class="field">
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
<select id="microphoneSelect">
<option value="">Default Microphone</option>
</select>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div>
</div>
</div> </div>
<div class="timer">00:00</div>
</div>
</button>
<div class="settings">
<div class="field">
<label for="websocketInput">WebSocket URL</label>
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
</div>
</div>
</div> </div>
</div>
<div class="theme-selector-container">
<div class="segmented" role="radiogroup" aria-label="Theme selector">
<input type="radio" id="theme-system" name="theme" value="system" />
<label for="theme-system" title="System">
<img src="/web/src/system_mode.svg" alt="" />
<span>System</span>
</label>
<input type="radio" id="theme-light" name="theme" value="light" />
<label for="theme-light" title="Light">
<img src="/web/src/light_mode.svg" alt="" />
<span>Light</span>
</label>
<input type="radio" id="theme-dark" name="theme" value="dark" />
<label for="theme-dark" title="Dark">
<img src="/web/src/dark_mode.svg" alt="" />
<span>Dark</span>
</label>
</div> </div>
</div>
<p id="status"></p>
<div id="linesTranscript"></div>
<script src="/web/live_transcription.js"></script> <p id="status"></p>
<div id="linesTranscript"></div>
<script src="/web/live_transcription.js"></script>
</body> </body>
</html>
</html>

View File

@@ -18,6 +18,8 @@ let animationFrame = null;
let waitingForStop = false; let waitingForStop = false;
let lastReceivedData = null; let lastReceivedData = null;
let lastSignature = null; let lastSignature = null;
let availableMicrophones = [];
let selectedMicrophoneId = null;
waveCanvas.width = 60 * (window.devicePixelRatio || 1); waveCanvas.width = 60 * (window.devicePixelRatio || 1);
waveCanvas.height = 30 * (window.devicePixelRatio || 1); waveCanvas.height = 30 * (window.devicePixelRatio || 1);
@@ -31,6 +33,7 @@ const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
const linesTranscriptDiv = document.getElementById("linesTranscript"); const linesTranscriptDiv = document.getElementById("linesTranscript");
const timerElement = document.querySelector(".timer"); const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]'); const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect");
function getWaveStroke() { function getWaveStroke() {
const styles = getComputedStyle(document.documentElement); const styles = getComputedStyle(document.documentElement);
@@ -82,6 +85,61 @@ if (darkMq && darkMq.addEventListener) {
darkMq.addListener(handleOsThemeChange); darkMq.addListener(handleOsThemeChange);
} }
async function enumerateMicrophones() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
stream.getTracks().forEach(track => track.stop());
const devices = await navigator.mediaDevices.enumerateDevices();
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
populateMicrophoneSelect();
console.log(`Found ${availableMicrophones.length} microphone(s)`);
} catch (error) {
console.error('Error enumerating microphones:', error);
statusText.textContent = "Error accessing microphones. Please grant permission.";
}
}
function populateMicrophoneSelect() {
if (!microphoneSelect) return;
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
availableMicrophones.forEach((device, index) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${index + 1}`;
microphoneSelect.appendChild(option);
});
const savedMicId = localStorage.getItem('selectedMicrophone');
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
microphoneSelect.value = savedMicId;
selectedMicrophoneId = savedMicId;
}
}
function handleMicrophoneChange() {
selectedMicrophoneId = microphoneSelect.value || null;
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
console.log(`Selected microphone: ${deviceName}`);
statusText.textContent = `Microphone changed to: ${deviceName}`;
if (isRecording) {
statusText.textContent = "Switching microphone... Please wait.";
stopRecording().then(() => {
setTimeout(() => {
toggleRecording();
}, 1000);
});
}
}
// Helpers // Helpers
function fmt1(x) { function fmt1(x) {
const n = Number(x); const n = Number(x);
@@ -377,7 +435,11 @@ async function startRecording() {
console.log("Error acquiring wake lock."); console.log("Error acquiring wake lock.");
} }
const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); const audioConstraints = selectedMicrophoneId
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
: { audio: true };
const stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
audioContext = new (window.AudioContext || window.webkitAudioContext)(); audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser(); analyser = audioContext.createAnalyser();
@@ -400,7 +462,12 @@ async function startRecording() {
isRecording = true; isRecording = true;
updateUI(); updateUI();
} catch (err) { } catch (err) {
statusText.textContent = "Error accessing microphone. Please allow microphone access."; if (window.location.hostname === "0.0.0.0") {
statusText.textContent =
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
} else {
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
}
console.error(err); console.error(err);
} }
} }
@@ -511,3 +578,22 @@ function updateUI() {
} }
recordButton.addEventListener("click", toggleRecording); recordButton.addEventListener("click", toggleRecording);
if (microphoneSelect) {
microphoneSelect.addEventListener("change", handleMicrophoneChange);
}
document.addEventListener('DOMContentLoaded', async () => {
try {
await enumerateMicrophones();
} catch (error) {
console.log("Could not enumerate microphones on load:", error);
}
});
navigator.mediaDevices.addEventListener('devicechange', async () => {
console.log('Device change detected, re-enumerating microphones');
try {
await enumerateMicrophones();
} catch (error) {
console.log("Error re-enumerating microphones:", error);
}
});

View File

@@ -1,5 +1,6 @@
import logging import logging
import importlib.resources as resources import importlib.resources as resources
import base64
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -12,6 +13,60 @@ def get_web_interface_html():
logger.error(f"Error loading web interface HTML: {e}") logger.error(f"Error loading web interface HTML: {e}")
return "<html><body><h1>Error loading interface</h1></body></html>" return "<html><body><h1>Error loading interface</h1></body></html>"
def get_inline_ui_html():
"""Returns the complete web interface HTML with all assets embedded in a single call."""
try:
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
html_content = f.read()
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
css_content = f.read()
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
js_content = f.read()
# SVG files
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
system_svg = f.read()
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
light_svg = f.read()
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
dark_svg = f.read()
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
# Replace external references
html_content = html_content.replace(
'<link rel="stylesheet" href="/web/live_transcription.css" />',
f'<style>\n{css_content}\n</style>'
)
html_content = html_content.replace(
'<script src="/web/live_transcription.js"></script>',
f'<script>\n{js_content}\n</script>'
)
# Replace SVG references
html_content = html_content.replace(
'<img src="/web/src/system_mode.svg" alt="" />',
f'<img src="{system_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/light_mode.svg" alt="" />',
f'<img src="{light_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/dark_mode.svg" alt="" />',
f'<img src="{dark_data_uri}" alt="" />'
)
return html_content
except Exception as e:
logger.error(f"Error creating embedded web interface: {e}")
return "<html><body><h1>Error loading embedded interface</h1></body></html>"
if __name__ == '__main__': if __name__ == '__main__':
@@ -28,6 +83,6 @@ if __name__ == '__main__':
@app.get("/") @app.get("/")
async def get(): async def get():
return HTMLResponse(get_web_interface_html()) return HTMLResponse(get_inline_ui_html())
uvicorn.run(app=app) uvicorn.run(app=app)