20 Commits

Author SHA1 Message Date
Quentin Fuxa
aa44a92a67 add embedded web interface HTML (single-file version with inline CSS/JS/SVG)
### Added
- `get_inline_ui_html()`: generates a self-contained version of the web interface, with CSS, JS, and SVG assets inlined directly into the HTML. useful for environments where serving static files is inconvenient or when a single-call UI delivery is preferred.
2025-08-29 21:58:51 +02:00
Quentin Fuxa
01d791470b add test files 2025-08-29 17:45:32 +02:00
Quentin Fuxa
4a5d5e1f3b raise Exception when language == auto and task == translation 2025-08-29 17:44:46 +02:00
Quentin Fuxa
583a2ec2e4 highlight Sortformer optional installation 2025-08-27 21:02:25 +02:00
Quentin Fuxa
19765e89e9 remove triton <3 condition 2025-08-27 20:44:39 +02:00
Quentin Fuxa
9895bc83bf auto detection of language for warmup if not indicated 2025-08-27 20:37:48 +02:00
Quentin Fuxa
ab98c31f16 trim will happen before audio processor 2025-08-27 18:17:11 +02:00
Quentin Fuxa
f9c9c4188a optional dependencies removed, ask to direct alternative package installations 2025-08-27 18:15:32 +02:00
Quentin Fuxa
c21d2302e7 to 0.2.7 2024-08-24 19:28:00 +02:00
Quentin Fuxa
4ed62e181d when silences are detected, speaker correction is no more applied 2024-08-24 19:24:00 +02:00
Quentin Fuxa
52a755a08c indications on how to choose a model 2024-08-24 19:22:00 +02:00
Quentin Fuxa
9a8d3cbd90 improve diarization + silence handling 2024-08-24 19:20:00 +02:00
Quentin Fuxa
b101ce06bd several users share the same sortformer model instance 2024-08-24 19:18:00 +02:00
Quentin Fuxa
c83fd179a8 improves phase shift correction between transcription and diarization 2024-08-24 19:15:00 +02:00
Quentin Fuxa
5258305745 default diarization backend in now sortformer 2025-08-24 18:32:01 +02:00
Quentin Fuxa
ce781831ee punctuation is checked in audio-processor's result formatter 2025-08-24 18:32:01 +02:00
Quentin Fuxa
58297daf6d sortformer diar implementation v0.3 2025-08-24 18:32:01 +02:00
Quentin Fuxa
3393a08f7e sortformer diar implementation v0.2 2025-08-24 18:32:01 +02:00
Quentin Fuxa
5b2ddeccdb correct pip installation error in image build 2025-08-22 15:37:46 +02:00
Quentin Fuxa
26cc1072dd new dockerfile for cpu only. update dockerfile from cuda 12.8 to 12.9 2025-08-22 11:04:35 +02:00
22 changed files with 1778 additions and 505 deletions

View File

@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
@@ -9,24 +9,20 @@ ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
# Install system dependencies
#RUN apt-get update && \
# apt-get install -y ffmpeg git && \
# apt-get clean && \
# rm -rf /var/lib/apt/lists/*
# 2) Install system dependencies + Python + pip
RUN apt-get update && \
apt-get install -y --no-install-recommends \
python3 \
python3-pip \
python3-venv \
ffmpeg \
git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129
COPY . .
@@ -35,10 +31,10 @@ COPY . .
# for more details.
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir .[$EXTRAS]; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir .; \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models by:
@@ -81,4 +77,4 @@ EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args
CMD ["--model", "base"]
CMD ["--model", "medium"]

61
Dockerfile.cpu Normal file
View File

@@ -0,0 +1,61 @@
FROM python:3.13-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
WORKDIR /app
ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg \
git \
build-essential \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
else \
echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \
fi
# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]
# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
mkdir -p /root/.cache/huggingface/hub && \
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
else \
echo "No local Hugging Face cache specified, skipping copy"; \
fi
# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
mkdir -p /root/.cache/huggingface && \
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
else \
echo "No Hugging Face token file specified, skipping token setup"; \
fi
# Expose port for the transcription server
EXPOSE 8000
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]

101
README.md
View File

@@ -8,7 +8,7 @@
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
@@ -66,41 +66,31 @@ pip install whisperlivekit
| Optional | `pip install` |
|-----------|-------------|
| Speaker diarization | `whisperlivekit[diarization]` |
| Original Whisper backend | `whisperlivekit[whisper]` |
| Improved timestamps backend | `whisperlivekit[whisper-timestamped]` |
| Apple Silicon optimization backend | `whisperlivekit[mlx-whisper]` |
| OpenAI API backend | `whisperlivekit[openai]` |
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisper` |
| Improved timestamps backend | `whisper-timestamped` |
| Apple Silicon optimization backend | `mlx-whisper` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them.
> **Pyannote Models Setup** For diarization, you need access to pyannote.audio models:
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
>4. Login with HuggingFace:
> ```bash
> huggingface-cli login
> ```
## 💻 Usage Examples
#### Command-line Interface
### Usage Examples
Start the transcription server with various options:
**Command-line Interface**: Start the transcription server with various options:
```bash
# SimulStreaming backend for ultra-low latency
whisperlivekit-server --backend simulstreaming --model large-v3
# Use better model than default (small)
whisperlivekit-server --model large-v3
# Advanced configuration with diarization
# Advanced configuration with diarization and language
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```
#### Python API Integration (Backend)
Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
@@ -138,17 +128,25 @@ async def websocket_endpoint(websocket: WebSocket):
await audio_processor.process_audio(message)
```
#### Frontend Implementation
The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`
### ⚙️ Parameters & Configuration
## Parameters & Configuration
An important list of parameters can be changed. But what *should* you change?
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
The rest I don't recommend. But below are your options.
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `en` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
@@ -185,9 +183,16 @@ The package includes an HTML/JavaScript implementation [here](https://github.com
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
> For diarization using Diart, you need access to pyannote.audio models:
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
>4. Login with HuggingFace: `huggingface-cli login`
### 🚀 Deployment Guide
@@ -216,19 +221,39 @@ To deploy WhisperLiveKit in production:
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
### 🐋 Docker
## 🐋 Docker
A Dockerfile is provided which allows re-use of Python package installation options. Create a reusable image with only the basics and then run as a named container:
Deploy the application easily using Docker with GPU or CPU support.
### Prerequisites
- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed
### Quick Start
**With GPU acceleration (recommended):**
```bash
docker build -t whisperlivekit-defaults .
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults --model base
docker start -i whisperlivekit
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```
> **Note**: For **large** models, ensure that your **docker runtime** has enough **memory** available
**CPU only:**
```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```
### Advanced Usage
**Custom configuration:**
```bash
# Example with custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```
### Memory Requirements
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
#### Customization

72
available_models.md Normal file
View File

@@ -0,0 +1,72 @@
# Available model sizes:
- tiny.en (english only)
- tiny
- base.en (english only)
- base
- small.en (english only)
- small
- medium.en (english only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo
## How to choose?
### Language Support
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use `.en` models.
### Resource Constraints
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
- `base`: Good balance of speed and accuracy for basic use cases
- `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
- `large-v2`: Excellent accuracy, good multilingual support
- `large-v3`: Best overall accuracy and language support
### Special Cases
- **No translation needed**: Use `large-v3-turbo`
- Same transcription quality as `large-v2` but significantly faster
- **Important**: Does not translate correctly, only transcribes
### Model Comparison Table
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
### Additional Considerations
**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting
**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least `small` model
### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.6"
version = "0.2.7"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
readme = "README.md"
authors = [
@@ -31,16 +31,11 @@ dependencies = [
"torch",
"tqdm",
"tiktoken",
'triton>=2.0.0,<3; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]
[project.optional-dependencies]
diarization = ["diart"]
sentence = ["mosestokenizer", "wtpsplit"]
whisper = ["whisper"]
whisper-timestamped = ["whisper-timestamped"]
mlx-whisper = ["mlx-whisper"]
openai = ["openai"]
[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"

View File

@@ -1,12 +1,13 @@
from .audio_processor import AudioProcessor
from .core import TranscriptionEngine
from .parse_args import parse_args
from .web.web_interface import get_web_interface_html
from .web.web_interface import get_web_interface_html, get_inline_ui_html
__all__ = [
"TranscriptionEngine",
"AudioProcessor",
"parse_args",
"get_web_interface_html",
"get_inline_ui_html",
"download_simulstreaming_backend",
]

View File

@@ -4,13 +4,11 @@ from time import time, sleep
import math
import logging
import traceback
from datetime import timedelta
from whisperlivekit.timed_objects import ASRToken, Silence
from whisperlivekit.core import TranscriptionEngine, online_factory
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.trail_repetition import trim_tail_repetition
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output, format_time
# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -18,10 +16,6 @@ logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
class AudioProcessor:
"""
Processes audio streams for transcription and diarization.
@@ -66,7 +60,6 @@ class AudioProcessor:
# Models and processing
self.asr = models.asr
self.tokenizer = models.tokenizer
self.diarization = models.diarization
self.vac_model = models.vac_model
if self.args.vac:
self.vac = FixedVADIterator(models.vac_model)
@@ -99,6 +92,11 @@ class AudioProcessor:
# Initialize transcription engine if enabled
if self.args.transcription:
self.online = online_factory(self.args, models.asr, models.tokenizer)
# Initialize diarization engine if enabled
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
def convert_pcm_to_float(self, pcm_buffer):
"""Convert PCM buffer in s16le format to normalized NumPy array."""
@@ -108,17 +106,6 @@ class AudioProcessor:
"""Thread-safe update of transcription with new data."""
async with self.lock:
self.tokens.extend(new_tokens)
# self.tokens, has_been_trimmed = trim_tail_repetition(
# self.tokens,
# key=lambda t: t.text.strip().lower(),
# min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too
# max_tail=200,
# prefer="longest", # prefer removing the longest repeated phrase
# keep=1
# )
# if has_been_trimmed:
# print('HAS BEEN TRIMMED !')
self.buffer_transcription = buffer
self.end_buffer = end_buffer
self.sep = sep
@@ -133,7 +120,7 @@ class AudioProcessor:
async def add_dummy_token(self):
"""Placeholder token when no transcription is available."""
async with self.lock:
current_time = time() - self.beg_loop
current_time = time() - self.beg_loop if self.beg_loop else 0
self.tokens.append(ASRToken(
start=current_time, end=current_time + 1,
text=".", speaker=-1, is_dummy=True
@@ -303,12 +290,12 @@ class AudioProcessor:
if type(item) is Silence:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
if self.tokens:
asr_processing_logs += " | last_end = {self.tokens[-1].end} |"
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
logger.info(asr_processing_logs)
if type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration, self.tokens[-1].end)
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
continue
if isinstance(item, np.ndarray):
@@ -433,7 +420,7 @@ class AudioProcessor:
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"]
# Add dummy tokens if needed
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
await self.add_dummy_token()
@@ -442,45 +429,13 @@ class AudioProcessor:
tokens = state["tokens"]
# Format output
previous_speaker = -1
lines = []
last_end_diarized = 0
undiarized_text = []
current_time = time() - self.beg_loop if self.beg_loop else None
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, self.silence)
for token in tokens:
speaker = token.speaker
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
# Handle diarization
if self.args.diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
debug_info = ""
if self.debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if speaker != previous_speaker or not lines:
lines.append({
"speaker": speaker,
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
})
previous_speaker = speaker
elif token.text: # Only append if text isn't empty
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
state,
self.silence,
current_time = time() - self.beg_loop if self.beg_loop else None,
diarization = self.args.diarization,
debug = self.debug
)
# Handle undiarized text
if undiarized_text:
combined = sep.join(undiarized_text)
@@ -510,7 +465,7 @@ class AudioProcessor:
"buffer_transcription": buffer_transcription,
"buffer_diarization": buffer_diarization,
"remaining_time_transcription": state["remaining_time_transcription"],
"remaining_time_diarization": state["remaining_time_diarization"]
"remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
}
current_response_signature = f"{response_status} | " + \

View File

@@ -2,7 +2,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio
import logging
from starlette.staticfiles import StaticFiles
@@ -38,7 +38,7 @@ app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/")
async def get():
return HTMLResponse(get_web_interface_html())
return HTMLResponse(get_inline_ui_html())
async def handle_websocket_results(websocket, results_generator):
@@ -52,7 +52,7 @@ async def handle_websocket_results(websocket, results_generator):
except WebSocketDisconnect:
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
except Exception as e:
logger.error(f"Error in WebSocket results handler: {e}")
logger.exception(f"Error in WebSocket results handler: {e}")
@app.websocket("/asr")

View File

@@ -57,7 +57,7 @@ class TranscriptionEngine:
"static_init_prompt": None,
"max_context_tokens": None,
"model_path": './base.pt',
"diarization_backend": "diart",
"diarization_backend": "sortformer",
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
@@ -121,13 +121,14 @@ class TranscriptionEngine:
if self.args.diarization:
if self.args.diarization_backend == "diart":
from whisperlivekit.diarization.diart_backend import DiartDiarization
self.diarization = DiartDiarization(
self.diarization_model = DiartDiarization(
block_duration=self.args.min_chunk_size,
segmentation_model_name=self.args.segmentation_model,
embedding_model_name=self.args.embedding_model
)
elif self.args.diarization_backend == "sortformer":
raise ValueError('Sortformer backend in developement')
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization_model = SortformerDiarization()
else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
@@ -152,4 +153,16 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
confidence_validation = args.confidence_validation
)
return online
def online_diarization_factory(args, diarization_backend):
if args.diarization_backend == "diart":
online = diarization_backend
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommanded
if args.diarization_backend == "sortformer":
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
online = SortformerDiarizationOnline(shared_model=diarization_backend)
return online

View File

@@ -1,145 +1,457 @@
import numpy as np
import torch
import logging
import threading
import time
import wave
from typing import List, Optional
from queue import SimpleQueue, Empty
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
def __init__(self):
self.spkcache = None # Speaker cache to store embeddings from start
self.spkcache_lengths = None
self.spkcache_preds = None # speaker cache predictions
self.fifo = None # to save the embedding from the latest chunks
self.fifo_lengths = None
self.fifo_preds = None
self.spk_perm = None
self.mean_sil_emb = None
self.n_sil_frames = None
class SortformerDiarization:
def __init__(self, model_name="nvidia/diar_streaming_sortformer_4spk-v2"):
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
self.diar_model.eval()
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
"""
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
"""
self._load_model(model_name)
def _load_model(self, model_name: str):
"""Load and configure the Sortformer model for streaming."""
try:
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
self.diar_model.eval()
if torch.cuda.is_available():
self.diar_model.to(torch.device("cuda"))
if torch.cuda.is_available():
self.diar_model.to(torch.device("cuda"))
logger.info("Using CUDA for Sortformer model")
else:
logger.info("Using CPU for Sortformer model")
# Streaming parameters for speed
self.diar_model.sortformer_modules.chunk_len = 12
self.diar_model.sortformer_modules.chunk_right_context = 1
self.diar_model.sortformer_modules.spkcache_len = 188
self.diar_model.sortformer_modules.fifo_len = 188
self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters()
self.batch_size = 1
self.processed_signal_offset = torch.zeros((self.batch_size,), dtype=torch.long, device=self.diar_model.device)
self.diar_model.sortformer_modules.chunk_len = 10
self.diar_model.sortformer_modules.subsampling_factor = 10
self.diar_model.sortformer_modules.chunk_right_context = 0
self.diar_model.sortformer_modules.chunk_left_context = 10
self.diar_model.sortformer_modules.spkcache_len = 188
self.diar_model.sortformer_modules.fifo_len = 188
self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters()
except Exception as e:
logger.error(f"Failed to load Sortformer model: {e}")
raise
class SortformerDiarizationOnline:
def __init__(self, shared_model, sample_rate: int = 16000):
"""
Initialize the streaming Sortformer diarization system.
self.audio_buffer = np.array([], dtype=np.float32)
self.sample_rate = 16000
Args:
sample_rate: Audio sample rate (default: 16000)
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
"""
self.sample_rate = sample_rate
self.speaker_segments = []
self.streaming_state = self.diar_model.sortformer_modules.init_streaming_state(
batch_size=self.batch_size,
async_streaming=True,
device=self.diar_model.device
self.buffer_audio = np.array([], dtype=np.float32)
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
self.processed_time = 0.0
self.debug = False
self.diar_model = shared_model.diar_model
self.audio2mel = AudioToMelSpectrogramPreprocessor(
window_size=0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0
)
self.total_preds = torch.zeros((self.batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=self.diar_model.device)
def _prepare_audio_signal(self, signal):
audio_signal = torch.tensor(signal).unsqueeze(0).to(self.diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(self.diar_model.device)
processed_signal, processed_signal_length = self.diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
return processed_signal, processed_signal_length
def _create_streaming_loader(self, processed_signal, processed_signal_length):
streaming_loader = self.diar_model.sortformer_modules.streaming_feat_loader(
feat_seq=processed_signal,
feat_seq_length=processed_signal_length,
feat_seq_offset=self.processed_signal_offset,
self.chunk_duration_seconds = (
self.diar_model.sortformer_modules.chunk_len *
self.diar_model.sortformer_modules.subsampling_factor *
self.diar_model.preprocessor._cfg.window_stride
)
return streaming_loader
self._init_streaming_state()
self._previous_chunk_features = None
self._chunk_index = 0
self._len_prediction = None
# Audio buffer to store PCM chunks for debugging
self.audio_buffer = []
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
self.audio_chunk_buffer = []
self.accumulated_duration = 0.0
logger.info("SortformerDiarization initialized successfully")
def _init_streaming_state(self):
"""Initialize the streaming state for the model."""
batch_size = 1
device = self.diar_model.device
self.streaming_state = StreamingSortformerState()
self.streaming_state.spkcache = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.spkcache_preds = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
device=device
)
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.fifo = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
# Initialize total predictions tensor
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: float):
"""
Insert silence period by adjusting the global time offset.
Args:
silence_duration: Duration of silence in seconds
"""
with self.segment_lock:
self.global_time_offset += silence_duration
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
async def diarize(self, pcm_array: np.ndarray):
"""
Process an incoming audio chunk for diarization.
Process audio data for diarization in streaming fashion.
Args:
pcm_array: Audio data as numpy array
"""
self.audio_buffer = np.concatenate([self.audio_buffer, pcm_array])
# Process in fixed-size chunks (e.g., 1 second)
chunk_size = self.sample_rate # 1 second of audio
while len(self.audio_buffer) >= chunk_size:
chunk_to_process = self.audio_buffer[:chunk_size]
self.audio_buffer = self.audio_buffer[chunk_size:]
try:
if self.debug:
self.audio_buffer.append(pcm_array.copy())
processed_signal, processed_signal_length = self._prepare_audio_signal(chunk_to_process)
threshold = int(self.chunk_duration_seconds * self.sample_rate)
current_offset_seconds = self.processed_signal_offset.item() * self.diar_model.preprocessor._cfg.window_stride
streaming_loader = self._create_streaming_loader(processed_signal, processed_signal_length)
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
if not len(self.buffer_audio) >= threshold:
return
frame_duration_s = self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride
chunk_duration_seconds = self.diar_model.sortformer_modules.chunk_len * frame_duration_s
audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:]
audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:]
total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total_features = processed_signal_chunk
self._previous_chunk_features = processed_signal_chunk
chunk_feat_seq_t = torch.transpose(total_features, 1, 2)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
# Convert predictions to speaker segments
self._process_predictions()
self._chunk_index += 1
except Exception as e:
logger.error(f"Error in diarize: {e}")
raise
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
with torch.inference_mode():
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=feat_lengths,
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
def _process_predictions(self):
"""Process model predictions and convert to speaker segments."""
try:
preds_np = self.total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if self._len_prediction is None:
self._len_prediction = len(active_speakers)
# Get predictions for current chunk
frame_duration = self.chunk_duration_seconds / self._len_prediction
current_chunk_preds = active_speakers[-self._len_prediction:]
with self.segment_lock:
# Process predictions into segments
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
for idx, spk in enumerate(current_chunk_preds):
start_time = base_time + idx * frame_duration
end_time = base_time + (idx + 1) * frame_duration
num_new_frames = feat_lengths[0].item()
# Get predictions for the current chunk from the end of total_preds
preds_np = self.total_preds[0, -num_new_frames:].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
for idx, spk in enumerate(active_speakers):
start_time = current_offset_seconds + (i * chunk_duration_seconds) + (idx * frame_duration_s)
end_time = start_time + frame_duration_s
# Check if this continues the last segment or starts a new one
if (self.speaker_segments and
self.speaker_segments[-1].speaker == spk and
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
# Continue existing segment
self.speaker_segments[-1].end = end_time
else:
if self.speaker_segments and self.speaker_segments[-1].speaker == spk + 1:
self.speaker_segments[-1].end = end_time
else:
self.speaker_segments.append(SpeakerSegment(
speaker=int(spk + 1),
start=start_time,
end=end_time
))
self.processed_signal_offset += processed_signal_length
# Create new segment
self.speaker_segments.append(SpeakerSegment(
speaker=spk,
start=start_time,
end=end_time
))
# Update processed time
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
except Exception as e:
logger.error(f"Error processing predictions: {e}")
def assign_speakers_to_tokens(self, tokens: list, **kwargs) -> list:
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Args:
tokens: List of tokens with timing information
use_punctuation_split: Whether to use punctuation for boundary refinement
Returns:
List of tokens with speaker assignments
"""
for token in tokens:
for segment in self.speaker_segments:
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = segment.speaker
with self.segment_lock:
segments = self.speaker_segments.copy()
if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment")
return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
use_punctuation_split = False
if not use_punctuation_split:
# Simple overlap-based assignment
for token in tokens:
token.speaker = -1 # Default to no speaker
for segment in segments:
# Check for timing overlap
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
break
else:
# Use punctuation-aware assignment (similar to diart_backend)
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
"""
Assign speakers to tokens with punctuation-aware boundary adjustment.
Args:
segments: List of speaker segments
tokens: List of tokens to assign speakers to
Returns:
List of tokens with speaker assignments
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
# Convert segments to concatenated format
segments_concatenated = self._concatenate_speakers(segments)
# Adjust segment boundaries based on punctuation
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
"""
Concatenate consecutive segments from the same speaker.
Args:
segments: List of speaker segments
Returns:
List of concatenated speaker segments
"""
if not segments:
return []
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = segment.speaker + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
return self.speaker_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.speaker_segments = [
segment for segment in self.speaker_segments
if current_time - segment.end < older_than
]
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
def close(self):
"""
Cleanup resources.
"""
logger.info("Closing SortformerDiarization.")
"""Close the diarization system and clean up resources."""
logger.info("Closing SortformerDiarization")
with self.segment_lock:
self.speaker_segments.clear()
if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio_data_int16.tobytes())
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
def extract_number(s: str) -> int:
"""Extract number from speaker string (compatibility function)."""
import re
m = re.search(r'\d+', s)
return int(m.group()) if m else 0
if __name__ == '__main__':
import asyncio
import librosa
an4_audio = 'new_audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
async def main():
"""TEST ONLY."""
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
diarization_pipeline = SortformerDiarization()
# Simulate streaming
chunk_size = 16000 # 1 second
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
import asyncio
asyncio.run(diarization_pipeline.diarize(chunk))
for segment in diarization_pipeline.speaker_segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
print("\n" + "=" * 50)
print("ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
diarization = SortformerDiarization(sample_rate=16000)
chunk_size = 1600
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}")
segments = diarization.get_segments()
print("\nDiarization results:")
for segment in segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
asyncio.run(main())

View File

@@ -1,257 +0,0 @@
import numpy as np
import torch
import logging
import math
logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
# diar_model.sortformer_modules.chunk_len = 6
# diar_model.sortformer_modules.spkcache_len = 188
# diar_model.sortformer_modules.chunk_right_context = 7
# diar_model.sortformer_modules.fifo_len = 188
# diar_model.sortformer_modules.spkcache_update_period = 144
# diar_model.sortformer_modules.log = False
# here we change the settings for our goal: speed!
# we want batches of around 1 second. one frame is 0.08s, so 1s is 12.5 frames. we take 12.
diar_model.sortformer_modules.chunk_len = 12
# for more speed, we reduce the 'right context'. it's like looking less into the future.
diar_model.sortformer_modules.chunk_right_context = 1
# we keep the rest same for now
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
batch_size = 1
processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)
# from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
# from nemo.collections.asr.modules.audio_preprocessing import get_features
from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessor
def prepare_audio_signal(signal):
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128).get_features(audio_signal, audio_signal_length)
return processed_signal, processed_signal_length
def streaming_feat_loader(
feat_seq, feat_seq_length, feat_seq_offset
):
"""
Load a chunk of feature sequence for streaming inference.
Args:
feat_seq (torch.Tensor): Tensor containing feature sequence
Shape: (batch_size, feat_dim, feat frame count)
feat_seq_length (torch.Tensor): Tensor containing feature sequence lengths
Shape: (batch_size,)
feat_seq_offset (torch.Tensor): Tensor containing feature sequence offsets
Shape: (batch_size,)
Returns:
chunk_idx (int): Index of the current chunk
chunk_feat_seq (torch.Tensor): Tensor containing the chunk of feature sequence
Shape: (batch_size, diar frame count, feat_dim)
feat_lengths (torch.Tensor): Tensor containing lengths of the chunk of feature sequence
Shape: (batch_size,)
"""
feat_len = feat_seq.shape[2]
num_chunks = math.ceil(feat_len / (diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor))
if False:
logging.info(
f"feat_len={feat_len}, num_chunks={num_chunks}, "
f"feat_seq_length={feat_seq_length}, feat_seq_offset={feat_seq_offset}"
)
stt_feat, end_feat, chunk_idx = 0, 0, 0
while end_feat < feat_len:
left_offset = min(diar_model.sortformer_modules.chunk_left_context * diar_model.sortformer_modules.subsampling_factor, stt_feat)
end_feat = min(stt_feat + diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor, feat_len)
right_offset = min(diar_model.sortformer_modules.chunk_right_context * diar_model.sortformer_modules.subsampling_factor, feat_len - end_feat)
chunk_feat_seq = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
0, chunk_feat_seq.shape[2]
)
feat_lengths = feat_lengths * (feat_seq_offset < end_feat)
stt_feat = end_feat
chunk_feat_seq_t = torch.transpose(chunk_feat_seq, 1, 2)
if False:
logging.info(
f"chunk_idx: {chunk_idx}, "
f"chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
f"chunk_feat_lengths: {feat_lengths}"
)
yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
chunk_idx += 1
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(signal, chunks):
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128).get_features(audio_signal, audio_signal_length)
streaming_loader = streaming_feat_loader(processed_signal, processed_signal_length, processed_signal_offset)
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
print(f"Chunk duration: {chunk_duration_seconds} seconds")
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t, _, _, _ in streaming_loader:
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
print(chunk_feat_seq_t.shape, total_preds.shape)
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': i * chunk_duration_seconds + idx * frame_duration,
'end_time': i * chunk_duration_seconds + (idx + 1) * frame_duration,
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
print(l_speakers)
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
if __name__ == '__main__':
import librosa
an4_audio = 'new_audio_test.mp3'
signal, sr = librosa.load(an4_audio,sr=16000)
"""
ground truth:
speaker 0 : 0:00 - 0:09
speaker 1 : 0:09 - 0:19
speaker 2 : 0:19 - 0:25
speaker 0 : 0:25 - end
"""
# Simulate streaming
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(signal, chunks)

View File

@@ -0,0 +1,205 @@
import numpy as np
import torch
import logging
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
import librosa
logger = logging.getLogger(__name__)
def load_model():
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
#we target 1 second lag for the moment. chunk_len could be reduced.
diar_model.sortformer_modules.chunk_len = 10
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
diar_model.sortformer_modules.chunk_right_context = 0 #no.
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
audio2mel = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
return diar_model, audio2mel
diar_model, audio2mel = load_model()
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(chunks):
"""
what it does:
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
2. STFT: Computes the Short-Time Fourier Transform using:
- the window of window_size=0.025 --> size of a window : 400 samples
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
6. Normalization: Skips normalization since `normalize="NA"`
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
"""
previous_chunk = None
l_chunk_feat_seq_t = []
for chunk in chunks:
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
if previous_chunk is not None:
to_add = previous_chunk[:, :, -99:]
total = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total = processed_signal_chunk
previous_chunk = processed_signal_chunk
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
batch_size = 1
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
for speaker in l_speakers:
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
if __name__ == '__main__':
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
# signal = signal[:-(len(signal)%16000)]
print("\n" + "=" * 50)
print("Expected ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(chunks)

View File

@@ -61,7 +61,7 @@ def parse_args():
parser.add_argument(
"--diarization-backend",
type=str,
default="diart",
default="sortformer",
choices=["sortformer", "diart"],
help="The diarization backend to use.",
)

View File

@@ -81,7 +81,7 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
if not tokens:
return [], buffer_transcription, buffer_diarization
last_token = tokens[-1]
if tokens and (
if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION
or
(current_time - last_token.end >= 3 and vac_detected_silence)

View File

@@ -0,0 +1,138 @@
import logging
from datetime import timedelta
from whisperlivekit.remove_silences import handle_silences
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?'}
CHECK_AROUND = 4
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
def is_punctuation(token):
if token.text.strip() in PUNCTUATION_MARKS:
return True
return False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if is_punctuation(tokens[ind]):
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if is_punctuation(token):
break
if token.speaker != speaker:
return ind, token.speaker
return None, speaker
def new_line(
token,
speaker,
last_end_diarized,
debug_info = ""
):
return {
"speaker": int(speaker),
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
}
def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
if token.text:
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
def format_output(state, silence, current_time, diarization, debug):
tokens = state["tokens"]
buffer_transcription = state["buffer_transcription"]
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]
sep = state["sep"]
previous_speaker = -1
lines = []
last_end_diarized = 0
undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
if diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
continue
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
speaker = previous_speaker
if speaker not in [-1, 0]:
last_end_diarized = max(token.end, last_end_diarized)
debug_info = ""
if debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if not lines:
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
continue
else:
previous_speaker = lines[-1]['speaker']
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if speaker != previous_speaker:
# perfect, diarization perfectly aligned
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
last_punctuation, next_punctuation = None, None
continue
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
lines.append(new_line(token, new_speaker, last_end_diarized, debug_info = ""))
else:
# No speaker change to come
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
if speaker != previous_speaker:
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
continue
elif next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
# should become:
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
# lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
pass
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
return lines, undiarized_text, buffer_transcription, ''

View File

@@ -42,6 +42,8 @@ class SimulStreamingOnlineProcessor:
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
#can be moved
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
@@ -212,7 +214,7 @@ class SimulStreamingASR():
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.original_language = lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
@@ -249,11 +251,6 @@ class SimulStreamingASR():
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
@@ -271,6 +268,12 @@ class SimulStreamingASR():
static_init_prompt=self.static_init_prompt,
)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.models = [self.load_model() for i in range(self.preload_model_count)]
@@ -281,7 +284,7 @@ class SimulStreamingASR():
def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path)
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model
def get_new_model_instance(self):
@@ -301,10 +304,12 @@ class SimulStreamingASR():
def set_translate_task(self):
"""Set up translation task."""
if self.cfg.language == 'auto':
raise Exception('Translation cannot be done with language = auto')
return tokenizer.get_tokenizer(
multilingual=True,
language=self.model.cfg.language,
num_languages=self.model.model.num_languages,
language=self.cfg.language,
num_languages=99,
task="translate"
)

View File

@@ -0,0 +1,60 @@
# gemma_translate.py
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "google/gemma-3-270m-it"
def build_prompt(tokenizer, text, target_lang, source_lang=None):
# Use the model's chat template for best results
if source_lang:
user_msg = (
f"Translate the following {source_lang} text into {target_lang}.\n"
f"Return only the translation.\n\n"
f"Text:\n{text}"
)
else:
user_msg = (
f"Translate the following text into {target_lang}.\n"
f"Return only the translation.\n\n"
f"Text:\n{text}"
)
chat = [{"role": "user", "content": user_msg}]
return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
def translate(text, target_lang, source_lang=None, max_new_tokens=256, temperature=0.2, top_p=0.95):
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
prompt = build_prompt(tokenizer, text, target_lang, source_lang)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature > 0.0,
eos_token_id=tokenizer.eos_token_id,
)
# Slice off the prompt to keep only the assistant answer
generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
out = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
return out
if __name__ == "__main__":
ap = argparse.ArgumentParser(description="Translate with google/gemma-3-270m-it")
ap.add_argument("--text", required=True, help="Text to translate")
ap.add_argument("--to", dest="target_lang", required=True, help="Target language (e.g., French, Spanish)")
ap.add_argument("--from", dest="source_lang", default=None, help="Source language (optional)")
ap.add_argument("--temp", type=float, default=0.2, help="Sampling temperature (0 = deterministic-ish)")
ap.add_argument("--max-new", type=int, default=256, help="Max new tokens")
args = ap.parse_args()
print(translate(args.text, args.target_lang, args.source_lang, max_new_tokens=args.max_new, temperature=args.temp))

View File

@@ -0,0 +1,121 @@
# nllb_translate.py
import argparse
from pathlib import Path
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MODEL_ID = "facebook/nllb-200-distilled-600M"
# Common language shortcuts → NLLB codes (extend as needed)
LANG_MAP = {
"english": "eng_Latn",
"en": "eng_Latn",
"french": "fra_Latn",
"fr": "fra_Latn",
"spanish": "spa_Latn",
"es": "spa_Latn",
"german": "deu_Latn",
"de": "deu_Latn",
"italian": "ita_Latn",
"it": "ita_Latn",
"portuguese": "por_Latn",
"pt": "por_Latn",
"arabic": "arb_Arab",
"ar": "arb_Arab",
"russian": "rus_Cyrl",
"ru": "rus_Cyrl",
"turkish": "tur_Latn",
"tr": "tur_Latn",
"chinese": "zho_Hans",
"zh": "zho_Hans", # Simplified
"zh-cn": "zho_Hans",
"zh-hans": "zho_Hans",
"zh-hant": "zho_Hant", # Traditional
"japanese": "jpn_Jpan",
"ja": "jpn_Jpan",
"korean": "kor_Hang",
"ko": "kor_Hang",
"dutch": "nld_Latn",
"nl": "nld_Latn",
"polish": "pol_Latn",
"pl": "pol_Latn",
"swedish": "swe_Latn",
"sv": "swe_Latn",
"norwegian": "nob_Latn",
"no": "nob_Latn",
"danish": "dan_Latn",
"da": "dan_Latn",
"finnish": "fin_Latn",
"fi": "fin_Latn",
"catalan": "cat_Latn",
"ca": "cat_Latn",
"hindi": "hin_Deva",
"hi": "hin_Deva",
"vietnamese": "vie_Latn",
"vi": "vie_Latn",
"indonesian": "ind_Latn",
"id": "ind_Latn",
"thai": "tha_Thai",
"th": "tha_Thai",
}
def norm_lang(code: str) -> str:
c = code.strip().lower()
return LANG_MAP.get(c, code)
def translate_texts(texts: List[str], src_code: str, tgt_code: str,
max_new_tokens=512, device=None, dtype=None) -> List[str]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, src_lang=src_code)
model = AutoModelForSeq2SeqLM.from_pretrained(
MODEL_ID,
torch_dtype=dtype if dtype is not None else (torch.float16 if torch.cuda.is_available() else torch.float32),
device_map="auto" if torch.cuda.is_available() else None,
)
if device:
model.to(device)
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
if device or torch.cuda.is_available():
inputs = {k: v.to(model.device) for k, v in inputs.items()}
forced_bos = tokenizer.convert_tokens_to_ids(tgt_code)
with torch.no_grad():
gen = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
forced_bos_token_id=forced_bos,
)
outs = tokenizer.batch_decode(gen, skip_special_tokens=True)
return [o.strip() for o in outs]
def main():
ap = argparse.ArgumentParser(description="Translate with facebook/nllb-200-distilled-600M")
ap.add_argument("--text", help="Inline text to translate")
ap.add_argument("--file", help="Path to a UTF-8 text file (one example per line)")
ap.add_argument("--src", required=True, help="Source language (e.g. fr, fra_Latn)")
ap.add_argument("--tgt", required=True, help="Target language (e.g. en, eng_Latn)")
ap.add_argument("--max-new", type=int, default=512, help="Max new tokens")
args = ap.parse_args()
src = norm_lang(args.src)
tgt = norm_lang(args.tgt)
batch: List[str] = []
if args.text:
batch.append(args.text)
if args.file:
lines = Path(args.file).read_text(encoding="utf-8").splitlines()
batch.extend([ln for ln in lines if ln.strip()])
if not batch:
raise SystemExit("Provide --text or --file")
results = translate_texts(batch, src, tgt, max_new_tokens=args.max_new)
for i, (inp, out) in enumerate(zip(batch, results), 1):
print(f"\n--- Sample {i} ---")
print(f"SRC [{src}]: {inp}")
print(f"TGT [{tgt}]: {out}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,38 @@
import regex
from functools import lru_cache
class SentenceSegmenter:
"""
Regex sentence splitter for Latin languages, Japanese and Chinese.
It is based on sacrebleu TokenizerV14International(BaseTokenizer).
Returns: a list of strings, where each string is a sentence.
Spaces following punctuation are appended after punctuation within the sequence.
Total number of characters in the output is the same as in the input.
"""
sep = 'ŽžŽžSentenceSeparatorŽžŽž' # string that certainly won't be in src or target
latin_terminals = '!?.'
jap_zh_terminals = '。!?'
terminals = latin_terminals + jap_zh_terminals
def __init__(self):
# end of sentence characters:
terminals = self.terminals
self._re = [
# Separate out punctuations preceeded by a non-digit.
# If followed by space-like sequence of characters, they are
# appended to the punctuation, not to the next sequence.
(regex.compile(r'(\P{N})(['+terminals+r'])(\p{Z}*)'), r'\1\2\3'+self.sep),
# Separate out punctuations followed by a non-digit
(regex.compile(r'('+terminals+r')(\P{N})'), r'\1'+self.sep+r'\2'),
# # Separate out symbols
# -> no, we don't tokenize but segment the punctuation
# (regex.compile(r'(\p{S})'), r' \1 '),
]
@lru_cache(maxsize=2**16)
def __call__(self, line):
for (_re, repl) in self._re:
line = _re.sub(repl, line)
return [ t for t in line.split(self.sep) if t != '' ]

View File

@@ -0,0 +1,466 @@
import sys
import ctranslate2
import sentencepiece as spm
import transformers
import argparse
def generate_words(sp, step_results):
tokens_buffer = []
for step_result in step_results:
is_new_word = step_result.token.startswith("")
if is_new_word and tokens_buffer:
word = sp.decode(tokens_buffer)
if word:
yield word
tokens_buffer = []
tokens_buffer.append(step_result.token_id)
if tokens_buffer:
word = sp.decode(tokens_buffer)
if word:
yield word
from sentence_segmenter import SentenceSegmenter
class LLMTranslator:
def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None):
self.system_prompt = system_prompt
print("Loading the model...", file=sys.stderr)
self.generator = ctranslate2.Generator("ct2_EuroLLM-9B-Instruct/", device="cuda")
self.sp = spm.SentencePieceProcessor("EuroLLM-9B-Instruct/tokenizer.model")
self.tokenizer = transformers.AutoTokenizer.from_pretrained("EuroLLM-9B-Instruct/")
print("...done", file=sys.stderr)
self.max_context_length = max_context_length
self.max_tokens_to_trim = self.max_context_length - 10
self.len_ratio = len_ratio
# my regex sentence segmenter
self.segmenter = SentenceSegmenter()
# self.max_generation_length = 512
# self.max_prompt_length = context_length - max_generation_length
def start_dialog(self):
return [{'role':'system', 'content': self.system_prompt }]
def build_prompt(self, dialog):
toks = self.tokenizer.apply_chat_template(dialog, tokenize=True, add_generation_prompt=False)
if len(dialog) == 3:
toks = toks[:-2]
print("len toks:", len(toks), file=sys.stderr)
# print(toks, file=sys.stderr)
c = self.tokenizer.convert_ids_to_tokens(toks)
# print(c,file=sys.stderr)
return c
def translate(self, src, tgt_forced=""):
#src, tgt_forced = self.trim(src, tgt_forced)
dialog = self.start_dialog()
dialog += [{'role':'user','content': src}]
if tgt_forced != "":
dialog += [{'role':'assistant','content': tgt_forced}]
prompt_tokens = self.build_prompt(dialog)
if self.len_ratio is not None:
limit_len = int(len(self.tokenizer.encode(src)) * self.len_ratio) + 10
limit_kw = {'max_length': limit_len}
else:
limit_kw = {}
step_results = self.generator.generate_tokens(
prompt_tokens,
**limit_kw,
# end_token=tokenizer.eos_token,
# sampling_temperature=0.6,
# sampling_topk=20,
# sampling_topp=1,
)
res = []
#output_ids = []
for step_result in step_results:
# is_new_word = step_result.token.startswith("▁")
# if is_new_word and output_ids:
# word = self.sp.decode(output_ids)
# print(word, end=" ", flush=True, file=sys.stderr)
# output_ids = []
# output_ids.append(step_result.token_id)
res.append(step_result)
#if output_ids:
# word = self.sp.decode(output_ids)
# print(word, file=sys.stderr)
return self.sp.decode([r.token_id for r in res])
# print(res)
# print([s.token for s in res], file=sys.stderr)
# print([s.token==self.tokenizer.eos_token for s in res], file=sys.stderr)
class ParallelTextBuffer:
def __init__(self, tokenizer, max_tokens, trimming="segments", init_src="", init_tgt=""):
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.src_buffer = [] # list of lists
if init_src:
self.src_buffer.append(init_src)
self.tgt_buffer = [] # list of strings
if init_tgt:
self.tgt_buffer.append(init_tgt)
self.trimming = trimming
if self.trimming == "sentences":
self.segmenter = SentenceSegmenter()
def len_src(self):
return sum(len(t) for t in self.src_buffer) + len(self.src_buffer) - 1
def insert(self, src, tgt):
self.src_buffer.append(src)
self.tgt_buffer.append(tgt)
def insert_src_suffix(self, s):
if self.src_buffer:
self.src_buffer[-1][-1] += s
else:
self.src_buffer.append([s])
def trim_sentences(self):
# src_tok_lens = [len(self.tokenizer.encode(" ".join(b))) for b in self.src_buffer]
# tgt_tok_lens = [len(self.tokenizer.encode(t)) for t in self.tgt_buffer]
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
def trim_sentence(text):
sents = self.segmenter(text)
print("SENTS:", len(sents), sents, file=sys.stderr)
return "".join(sents[1:])
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
nsrc = trim_sentence(src)
ntgt = trim_sentence(tgt)
if not nsrc or not ntgt:
print("src or tgt is empty after trimming.", file=sys.stderr)
print("src: ", src, file=sys.stderr)
print("tgt: ", tgt, file=sys.stderr)
break
src = nsrc
tgt = ntgt
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
print("TRIMMED SRC:", (src,), file=sys.stderr)
print("TRIMMED TGT:", (tgt,), file=sys.stderr)
self.src_buffer = [src.split()]
self.tgt_buffer = [tgt]
return src, tgt
def trim_segments(self):
print("BUFFER:", file=sys.stderr)
for s,t in zip(self.src_buffer, self.tgt_buffer):
print("\t", s,"...",t,file=sys.stderr) #,self.src_buffer, self.tgt_buffer, file=sys.stderr)
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
if len(self.src_buffer) > 1 and len(self.tgt_buffer) > 1:
self.src_buffer.pop(0)
self.tgt_buffer.pop(0)
else:
break
src = " ".join(" ".join(b) for b in self.src_buffer)
tgt = "".join(self.tgt_buffer)
src_sp_toks = self.tokenizer.encode(src)
tgt_sp_toks = self.tokenizer.encode(tgt)
print("TRIMMED SEGMENTS SRC:", (src,), file=sys.stderr)
print("TRIMMED SEGMENTS TGT:", (tgt,), file=sys.stderr)
return src, tgt
def trim(self):
if self.trimming == "sentences":
return self.trim_sentences()
return self.trim_segments()
class SimulLLM:
def __init__(self, llmtrans, min_len=0, chunk=1, trimming="sentences", language="ja", init_src="", init_tgt=""):
self.llmtranslator = llmtrans
#self.src_buffer = init_src
#self.confirmed_tgt = init_tgt
self.buffer = ParallelTextBuffer(self.llmtranslator.tokenizer, self.llmtranslator.max_tokens_to_trim, trimming=trimming, init_src=init_src, init_tgt=init_tgt)
self.last_inserted = []
self.last_unconfirmed = ""
self.min_len = min_len
self.step = chunk
self.language = language
if language in ["ja", "zh"]:
self.specific_space = ""
else:
self.specific_space = " "
def insert(self, src):
if isinstance(src, str):
self.last_inserted.append(src)
else:
self.last_inserted += src
def insert_suffix(self, text):
'''
Insert suffix of a word to the last inserted word.
It may be because the word was split to multiple parts in the input, each with different timestamps.
'''
if self.last_inserted:
self.last_inserted[-1] += text
elif self.src_buffer:
self.buffer.insert_src_suffix(text)
else:
# this shouldn't happen
self.last_inserted.append(text)
def trim_longest_common_prefix(self, a,b):
if self.language not in ["ja", "zh"]:
a = a.split()
b = b.split()
i = 0
for i,(x,y) in enumerate(zip(a,b)):
if x != y:
break
if self.language in ["ja", "zh"]:
#print("tady160",(a, b, i), file=sys.stderr)
return a[:i], b[i:]
else:
return " ".join(a[:i]), " ".join(b[i:])
def process_iter(self):
if self.buffer.len_src() + len(self.last_inserted) < self.min_len:
return ""
src, forced_tgt = self.buffer.trim() #llmtranslator.trim(" ".join(self.src_buffer), self.confirmed_tgt)
#self.src_buffer = self.src_buffer.split()
#src = " ".join(self.src_buffer)
confirmed_out = ""
run = False
for i in range(self.step, len(self.last_inserted), self.step):
for w in self.last_inserted[i-self.step:i]:
src += " " + w
run = True
if not run: break
print("SRC",src,file=sys.stderr)
print("FORCED TGT",forced_tgt,file=sys.stderr)
out = self.llmtranslator.translate(src, forced_tgt)
print("OUT",out,file=sys.stderr)
confirmed, unconfirmed = self.trim_longest_common_prefix(self.last_unconfirmed, out)
self.last_unconfirmed = unconfirmed
#print("tady", (self.confirmed_tgt, self.specific_space, confirmed), file=sys.stderr)
if confirmed:
# self.confirmed_tgt += self.specific_space + confirmed
# print(confirmed_out, confirmed, file=sys.stderr)
confirmed_out += self.specific_space + confirmed
print("CONFIRMED NOW:",confirmed,file=sys.stderr)
print(file=sys.stderr)
print(file=sys.stderr)
print("#################",file=sys.stderr)
if run:
self.buffer.insert(self.last_inserted, confirmed_out)
self.last_inserted = []
ret = confirmed_out
print("RET:",ret,file=sys.stderr)
return ret
def finalize(self):
return self.last_unconfirmed
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input-instance', type=str, default=None, help="Filename of instances to simulate input. If not set, txt input is read from stdin.")
#parser.add_argument('--output_instance', type=str, default=None, help="Write output as instance into this file, while also writing to stdout.")
parser.add_argument('--min-chunk-size', type=int, default=1,
help='Minimum number of space-delimited words to process in each LocalAgreement update. The more, the higher quality, but slower.')
parser.add_argument('--min-len', type=int, default=1,
help='Minimum number of space-delimited words at the beginning.')
#parser.add_argument('--start_at', type=int, default=0, help='Skip first N words.')
# maybe later
#parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
#parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
lan_to_name = {
"de": "German",
"ja": "Japanese",
"zh-tr": "Chinese Traditional",
"zh-sim": "Chinese Simplified",
"cs": "Czech",
}
parser.add_argument('--lan', '--language', type=str, default="de",
help="Target language code.",
choices=["de", "ja","zh-tr","zh-sim","cs"])
SrcLang = "English" # always
TgtLang = "German"
default_prompt="You are simultaneous interpreter from {SrcLang} to {TgtLang}. We are at a conference. It is important that you translate " + \
"only what you hear, nothing else!"
parser.add_argument('--sys_prompt', type=str, default=None,
help='System prompt. If None, default one is used, depending on the language. The prompt should ')
default_init = "Please, go ahead, you can start with your presentation, we are ready."
default_inits_tgt = {
'de': "Bitte schön, Sie können mit Ihrer Präsentation beginnen, wir sind bereit.",
'ja': "どうぞ、プレゼンテーションを始めてください。", # # Please go ahead and start your presentation. # this is in English
'zh-tr': "請繼續,您可以開始您的簡報,我們已經準備好了。",
'zh-sim': "请吧,你可以开始发言了,我们已经准备好了。",
'cs': "Prosím, můžete začít s prezentací, jsme připraveni.",
}
parser.add_argument('--init_prompt_src', type=str, default=None, help='Init translation with source text. It should be a complete sentence in the source language. '
'It can be context specific for the given input. Default is ')
parser.add_argument('--init_prompt_tgt', type=str, default=None, help='Init translation with this target. It should be example translation of init_prompt_src. '
' There is default init message, depending on the language.')
parser.add_argument('--len-threshold', type=float, default=None, help='Ratio of the length of the source and generated target, in number of sentencepiece tokens. '
'It should reflect the target language and. If not set, no len-threshold is used.')
# how many times is target text longer than English
lan_thresholds = {
'de': 1.3, # 12751/9817 ... the proportion of subword tokens for ACL6060 dev de vs. en text, for EuroLLM-9B-Instruct tokenizer
'ja': 1.34, # 13187/9817
'zh': 1.23, # 12115/9817
'zh-tr': 1.23, # 12115/9817
'zh-sim': 1.23, # 12115/9817
# 'cs': I don't know # guessed
}
parser.add_argument('--language-specific-len-threshold', default=False, action="store_true",
help='Use language-specific length threshold, e.g. 1.3 for German.')
parser.add_argument("--max-context-length", type=int, default=4096, help="Maximum number of tokens in the model to use.")
parser.add_argument("--buffer_trimming", type=str, default="sentences", choices=["segments","sentences"], help="Buffer trimming strategy.")
args = parser.parse_args()
if args.sys_prompt is None:
TgtLang = lan_to_name[args.lan]
sys_prompt = default_prompt.format(SrcLang=SrcLang, TgtLang=TgtLang)
else:
sys_prompt = args.sys_prompt
if args.init_prompt_src is None:
init_src = default_init.split()
if args.init_prompt_tgt is None:
init_tgt = default_inits_tgt[args.lan]
if args.lan == "ja":
init_src = 'Please go ahead and start your presentation.'.split()
print("WARNING: Default init_prompt_src not set and language is Japanese. The init_src prompt changed to be more verbose.", file=sys.stderr)
else:
print("WARNING: init_prompt_tgt is used, init_prompt_src is None, the default one. It may be wrong!", file=sys.stderr)
init_tgt = args.init_prompt_tgt
else:
init_src = args.init_prompt_src.split()
if args.init_prompt_tgt is None:
print("WARNING: init_prompt_src is used, init_prompt_tgt is None, so the default one is used. It may be wrong!", file=sys.stderr)
init_tgt = default_inits_tgt[args.lan]
else:
init_tgt = args.init_prompt_tgt
print("INFO: System prompt:", sys_prompt, file=sys.stderr)
print("INFO: Init prompt src:", init_src, file=sys.stderr)
print("INFO: Init prompt tgt:", init_tgt, file=sys.stderr)
if args.language_specific_len_threshold:
if args.len_threshold is not None:
print("ERROR: --len-threshold is set, but --language-specific-len-threshold is also set. Only one can be used.", file=sys.stderr)
sys.exit(1)
else:
len_threshold = lan_thresholds[args.lan]
else:
len_threshold = args.len_threshold
llmtrans = LLMTranslator(system_prompt=sys_prompt, max_context_length=args.max_context_length, len_ratio=len_threshold)
lan = args.lan if not args.lan.startswith("zh") else "zh"
simul = SimulLLM(llmtrans,language=lan, min_len=args.min_len, chunk=args.min_chunk_size,
init_src=init_src, init_tgt=init_tgt, trimming=args.buffer_trimming
)
# two input options
if args.input_instance is not None:
print("INFO: Reading input from file", args.input_instance, file=sys.stderr)
import json
with open(args.input_instance, "r") as f:
instance = json.load(f)
asr_source = instance["prediction"]
timestamps = instance["delays"]
elapsed = instance["elapsed"]
yield_ts_words = zip(timestamps, timestamps, elapsed, asr_source.split())
else:
print("INFO: Reading stdin in txt format", file=sys.stderr)
def yield_input():
for line in sys.stdin:
line = line.strip()
ts, beg, end, *_ = line.split()
text = line[len(ts)+len(beg)+len(end)+3:]
ts = float(ts)
# in rare cases, the first word is a suffix of the previous word, that was split to multiple parts
if text[0] != " ":
first, *words = text.split()
yield (ts, beg, end, " "+first) # marking the first word with " ", so that it can be later detected and inserted as suffix
else:
words = text.split()
for w in words:
yield (ts, beg, end, w)
yield_ts_words = yield_input()
#i = 0
for t,b,e,w in yield_ts_words:
if w.startswith(" "): # it is suffix of the previous word
w = w[1:]
simul.insert_suffix(w)
continue
simul.insert(w)
out = simul.process_iter()
if out:
print(t,b,e,out,flush=True)
# if i > 50:
# break
# i += 1
out = simul.finalize()
print(t,b,e,out,flush=True)

View File

@@ -400,7 +400,12 @@ async function startRecording() {
isRecording = true;
updateUI();
} catch (err) {
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
if (window.location.hostname === "0.0.0.0") {
statusText.textContent =
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
} else {
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
}
console.error(err);
}
}

View File

@@ -1,5 +1,6 @@
import logging
import importlib.resources as resources
import base64
logger = logging.getLogger(__name__)
@@ -12,6 +13,67 @@ def get_web_interface_html():
logger.error(f"Error loading web interface HTML: {e}")
return "<html><body><h1>Error loading interface</h1></body></html>"
def get_inline_ui_html():
"""Returns the complete web interface HTML with all assets embedded in a single call."""
try:
# Load HTML template
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
html_content = f.read()
# Load CSS and embed it
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
css_content = f.read()
# Load JS and embed it
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
js_content = f.read()
# Load SVG files and convert to data URIs
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
system_svg = f.read()
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
light_svg = f.read()
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
dark_svg = f.read()
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
# Replace external references with embedded content
html_content = html_content.replace(
'<link rel="stylesheet" href="/web/live_transcription.css" />',
f'<style>\n{css_content}\n</style>'
)
html_content = html_content.replace(
'<script src="/web/live_transcription.js"></script>',
f'<script>\n{js_content}\n</script>'
)
# Replace SVG references with data URIs
html_content = html_content.replace(
'<img src="/web/src/system_mode.svg" alt="" />',
f'<img src="{system_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/light_mode.svg" alt="" />',
f'<img src="{light_data_uri}" alt="" />'
)
html_content = html_content.replace(
'<img src="/web/src/dark_mode.svg" alt="" />',
f'<img src="{dark_data_uri}" alt="" />'
)
return html_content
except Exception as e:
logger.error(f"Error creating embedded web interface: {e}")
return "<html><body><h1>Error loading embedded interface</h1></body></html>"
if __name__ == '__main__':
@@ -28,6 +90,6 @@ if __name__ == '__main__':
@app.get("/")
async def get():
return HTMLResponse(get_web_interface_html())
return HTMLResponse(get_inline_ui_html())
uvicorn.run(app=app)
uvicorn.run(app=app)