mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-09 15:25:34 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab98c31f16 | ||
|
|
f9c9c4188a | ||
|
|
c21d2302e7 | ||
|
|
4ed62e181d | ||
|
|
52a755a08c | ||
|
|
9a8d3cbd90 | ||
|
|
b101ce06bd | ||
|
|
c83fd179a8 | ||
|
|
5258305745 | ||
|
|
ce781831ee | ||
|
|
58297daf6d | ||
|
|
3393a08f7e | ||
|
|
5b2ddeccdb | ||
|
|
26cc1072dd |
20
Dockerfile
20
Dockerfile
@@ -1,4 +1,4 @@
|
||||
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
|
||||
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
@@ -9,24 +9,20 @@ ARG EXTRAS
|
||||
ARG HF_PRECACHE_DIR
|
||||
ARG HF_TKN_FILE
|
||||
|
||||
# Install system dependencies
|
||||
#RUN apt-get update && \
|
||||
# apt-get install -y ffmpeg git && \
|
||||
# apt-get clean && \
|
||||
# rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 2) Install system dependencies + Python + pip
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3-venv \
|
||||
ffmpeg \
|
||||
git \
|
||||
build-essential \
|
||||
python3-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129
|
||||
|
||||
COPY . .
|
||||
|
||||
@@ -35,10 +31,10 @@ COPY . .
|
||||
# for more details.
|
||||
RUN if [ -n "$EXTRAS" ]; then \
|
||||
echo "Installing with extras: [$EXTRAS]"; \
|
||||
pip install --no-cache-dir .[$EXTRAS]; \
|
||||
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
|
||||
else \
|
||||
echo "Installing base package only"; \
|
||||
pip install --no-cache-dir .; \
|
||||
pip install --no-cache-dir whisperlivekit; \
|
||||
fi
|
||||
|
||||
# Enable in-container caching for Hugging Face models by:
|
||||
@@ -81,4 +77,4 @@ EXPOSE 8000
|
||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||
|
||||
# Default args
|
||||
CMD ["--model", "base"]
|
||||
CMD ["--model", "medium"]
|
||||
61
Dockerfile.cpu
Normal file
61
Dockerfile.cpu
Normal file
@@ -0,0 +1,61 @@
|
||||
FROM python:3.13-slim
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ARG EXTRAS
|
||||
ARG HF_PRECACHE_DIR
|
||||
ARG HF_TKN_FILE
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
git \
|
||||
build-essential \
|
||||
python3-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CPU-only PyTorch
|
||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
COPY . .
|
||||
|
||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
||||
RUN if [ -n "$EXTRAS" ]; then \
|
||||
echo "Installing with extras: [$EXTRAS]"; \
|
||||
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
|
||||
else \
|
||||
echo "Installing base package only"; \
|
||||
pip install --no-cache-dir whisperlivekit; \
|
||||
fi
|
||||
|
||||
# Enable in-container caching for Hugging Face models
|
||||
VOLUME ["/root/.cache/huggingface/hub"]
|
||||
|
||||
# Conditionally copy a local pre-cache from the build context
|
||||
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
|
||||
echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
|
||||
mkdir -p /root/.cache/huggingface/hub && \
|
||||
cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
|
||||
else \
|
||||
echo "No local Hugging Face cache specified, skipping copy"; \
|
||||
fi
|
||||
|
||||
# Conditionally copy a Hugging Face token if provided
|
||||
RUN if [ -n "$HF_TKN_FILE" ]; then \
|
||||
echo "Copying Hugging Face token from $HF_TKN_FILE"; \
|
||||
mkdir -p /root/.cache/huggingface && \
|
||||
cp $HF_TKN_FILE /root/.cache/huggingface/token; \
|
||||
else \
|
||||
echo "No Hugging Face token file specified, skipping token setup"; \
|
||||
fi
|
||||
|
||||
# Expose port for the transcription server
|
||||
EXPOSE 8000
|
||||
|
||||
ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
|
||||
|
||||
# Default args - you might want to use a smaller model for CPU
|
||||
CMD ["--model", "tiny"]
|
||||
69
README.md
69
README.md
@@ -8,7 +8,7 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||
</p>
|
||||
@@ -66,11 +66,12 @@ pip install whisperlivekit
|
||||
|
||||
| Optional | `pip install` |
|
||||
|-----------|-------------|
|
||||
| Speaker diarization | `whisperlivekit[diarization]` |
|
||||
| Original Whisper backend | `whisperlivekit[whisper]` |
|
||||
| Improved timestamps backend | `whisperlivekit[whisper-timestamped]` |
|
||||
| Apple Silicon optimization backend | `whisperlivekit[mlx-whisper]` |
|
||||
| OpenAI API backend | `whisperlivekit[openai]` |
|
||||
| Speaker diarization with Sortformer | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| Speaker diarization with Diart | `diart` |
|
||||
| Original Whisper backend | `whisper` |
|
||||
| Improved timestamps backend | `whisper-timestamped` |
|
||||
| Apple Silicon optimization backend | `mlx-whisper` |
|
||||
| OpenAI API backend | `openai` |
|
||||
|
||||
See **Parameters & Configuration** below on how to use them.
|
||||
|
||||
@@ -91,10 +92,10 @@ See **Parameters & Configuration** below on how to use them.
|
||||
Start the transcription server with various options:
|
||||
|
||||
```bash
|
||||
# SimulStreaming backend for ultra-low latency
|
||||
whisperlivekit-server --backend simulstreaming --model large-v3
|
||||
# Use better model than default (small)
|
||||
whisperlivekit-server --model large-v3
|
||||
|
||||
# Advanced configuration with diarization
|
||||
# Advanced configuration with diarization and language
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
@@ -145,6 +146,16 @@ The package includes an HTML/JavaScript implementation [here](https://github.com
|
||||
|
||||
### ⚙️ Parameters & Configuration
|
||||
|
||||
An important list of parameters can be changed. But what *should* you change?
|
||||
- the `--model` size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
||||
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)
|
||||
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
|
||||
- `--warmup-file`, if you have one
|
||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
|
||||
- `--diarization`, if you want to use it.
|
||||
|
||||
The rest I don't recommend. But below are your options.
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. | `small` |
|
||||
@@ -185,9 +196,9 @@ The package includes an HTML/JavaScript implementation [here](https://github.com
|
||||
| Diarization options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
|
||||
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
### 🚀 Deployment Guide
|
||||
|
||||
@@ -216,19 +227,39 @@ To deploy WhisperLiveKit in production:
|
||||
|
||||
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
||||
|
||||
### 🐋 Docker
|
||||
## 🐋 Docker
|
||||
|
||||
A Dockerfile is provided which allows re-use of Python package installation options. Create a reusable image with only the basics and then run as a named container:
|
||||
Deploy the application easily using Docker with GPU or CPU support.
|
||||
|
||||
### Prerequisites
|
||||
- Docker installed on your system
|
||||
- For GPU support: NVIDIA Docker runtime installed
|
||||
|
||||
### Quick Start
|
||||
|
||||
**With GPU acceleration (recommended):**
|
||||
```bash
|
||||
docker build -t whisperlivekit-defaults .
|
||||
docker create --gpus all --name whisperlivekit -p 8000:8000 whisperlivekit-defaults --model base
|
||||
docker start -i whisperlivekit
|
||||
docker build -t wlk .
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
> **Note**: For **large** models, ensure that your **docker runtime** has enough **memory** available
|
||||
**CPU only:**
|
||||
```bash
|
||||
docker build -f Dockerfile.cpu -t wlk .
|
||||
docker run -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
### Advanced Usage
|
||||
|
||||
**Custom configuration:**
|
||||
```bash
|
||||
# Example with custom model and language
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||
```
|
||||
|
||||
### Memory Requirements
|
||||
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
|
||||
|
||||
> **Note**: If you're running on a system without NVIDIA GPU support (such as Mac with Apple Silicon or any system without CUDA capabilities), you need to **remove the `--gpus all` flag** from the `docker create` command. Without GPU acceleration, transcription will use CPU only, which may be significantly slower. Consider using small models for better performance on CPU-only systems.
|
||||
|
||||
#### Customization
|
||||
|
||||
|
||||
72
available_models.md
Normal file
72
available_models.md
Normal file
@@ -0,0 +1,72 @@
|
||||
# Available model sizes:
|
||||
|
||||
- tiny.en (english only)
|
||||
- tiny
|
||||
- base.en (english only)
|
||||
- base
|
||||
- small.en (english only)
|
||||
- small
|
||||
- medium.en (english only)
|
||||
- medium
|
||||
- large-v1
|
||||
- large-v2
|
||||
- large-v3
|
||||
- large-v3-turbo
|
||||
|
||||
## How to choose?
|
||||
|
||||
### Language Support
|
||||
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
|
||||
- **Multilingual**: Do not use `.en` models.
|
||||
|
||||
### Resource Constraints
|
||||
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
|
||||
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
|
||||
- `base`: Good balance of speed and accuracy for basic use cases
|
||||
- `small`: Better accuracy while still being resource-efficient
|
||||
- **Good resources available**: Use `large` models for best accuracy
|
||||
- `large-v2`: Excellent accuracy, good multilingual support
|
||||
- `large-v3`: Best overall accuracy and language support
|
||||
|
||||
### Special Cases
|
||||
- **No translation needed**: Use `large-v3-turbo`
|
||||
- Same transcription quality as `large-v2` but significantly faster
|
||||
- **Important**: Does not translate correctly, only transcribes
|
||||
|
||||
### Model Comparison Table
|
||||
|
||||
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|
||||
|-------|--------|----------|--------------|-------------|---------------|
|
||||
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
|
||||
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
|
||||
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
|
||||
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
|
||||
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
|
||||
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
|
||||
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
|
||||
|
||||
### Additional Considerations
|
||||
|
||||
**Model Performance**:
|
||||
- Accuracy improves significantly from tiny to large models
|
||||
- English-only models are ~10-15% more accurate for English audio
|
||||
- Newer versions (v2, v3) have better punctuation and formatting
|
||||
|
||||
**Hardware Requirements**:
|
||||
- `tiny`: ~1GB VRAM
|
||||
- `base`: ~1GB VRAM
|
||||
- `small`: ~2GB VRAM
|
||||
- `medium`: ~5GB VRAM
|
||||
- `large`: ~10GB VRAM
|
||||
|
||||
**Audio Quality Impact**:
|
||||
- Clean, clear audio: smaller models may suffice
|
||||
- Noisy, accented, or technical audio: larger models recommended
|
||||
- Phone/low-quality audio: use at least `small` model
|
||||
|
||||
### Quick Decision Tree
|
||||
1. English only? → Add `.en` to your choice
|
||||
2. Limited resources or need speed? → `small` or smaller
|
||||
3. Good hardware and want best quality? → `large-v3`
|
||||
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
|
||||
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.6"
|
||||
version = "0.2.7"
|
||||
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -35,12 +35,7 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
diarization = ["diart"]
|
||||
sentence = ["mosestokenizer", "wtpsplit"]
|
||||
whisper = ["whisper"]
|
||||
whisper-timestamped = ["whisper-timestamped"]
|
||||
mlx-whisper = ["mlx-whisper"]
|
||||
openai = ["openai"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
|
||||
@@ -4,13 +4,11 @@ from time import time, sleep
|
||||
import math
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.trail_repetition import trim_tail_repetition
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output, format_time
|
||||
# Set up logging once
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,10 +16,6 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
Processes audio streams for transcription and diarization.
|
||||
@@ -66,7 +60,6 @@ class AudioProcessor:
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.tokenizer = models.tokenizer
|
||||
self.diarization = models.diarization
|
||||
self.vac_model = models.vac_model
|
||||
if self.args.vac:
|
||||
self.vac = FixedVADIterator(models.vac_model)
|
||||
@@ -99,6 +92,11 @@ class AudioProcessor:
|
||||
# Initialize transcription engine if enabled
|
||||
if self.args.transcription:
|
||||
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||
|
||||
# Initialize diarization engine if enabled
|
||||
if self.args.diarization:
|
||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||
@@ -108,17 +106,6 @@ class AudioProcessor:
|
||||
"""Thread-safe update of transcription with new data."""
|
||||
async with self.lock:
|
||||
self.tokens.extend(new_tokens)
|
||||
|
||||
# self.tokens, has_been_trimmed = trim_tail_repetition(
|
||||
# self.tokens,
|
||||
# key=lambda t: t.text.strip().lower(),
|
||||
# min_block=2, # avoid trimming single '.' loops; set to 1 if you want to remove those too
|
||||
# max_tail=200,
|
||||
# prefer="longest", # prefer removing the longest repeated phrase
|
||||
# keep=1
|
||||
# )
|
||||
# if has_been_trimmed:
|
||||
# print('HAS BEEN TRIMMED !')
|
||||
self.buffer_transcription = buffer
|
||||
self.end_buffer = end_buffer
|
||||
self.sep = sep
|
||||
@@ -303,7 +290,7 @@ class AudioProcessor:
|
||||
if type(item) is Silence:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
if self.tokens:
|
||||
asr_processing_logs += " | last_end = {self.tokens[-1].end} |"
|
||||
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
if type(item) is Silence:
|
||||
@@ -433,7 +420,7 @@ class AudioProcessor:
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
sep = state["sep"]
|
||||
|
||||
|
||||
# Add dummy tokens if needed
|
||||
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
|
||||
await self.add_dummy_token()
|
||||
@@ -442,45 +429,13 @@ class AudioProcessor:
|
||||
tokens = state["tokens"]
|
||||
|
||||
# Format output
|
||||
previous_speaker = -1
|
||||
lines = []
|
||||
last_end_diarized = 0
|
||||
undiarized_text = []
|
||||
current_time = time() - self.beg_loop if self.beg_loop else None
|
||||
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, self.silence)
|
||||
for token in tokens:
|
||||
speaker = token.speaker
|
||||
|
||||
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
speaker = 1
|
||||
|
||||
# Handle diarization
|
||||
if self.args.diarization and not tokens[-1].speaker == -2:
|
||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
if speaker not in [-1, 0]:
|
||||
last_end_diarized = max(token.end, last_end_diarized)
|
||||
|
||||
debug_info = ""
|
||||
if self.debug:
|
||||
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
|
||||
if speaker != previous_speaker or not lines:
|
||||
lines.append({
|
||||
"speaker": speaker,
|
||||
"text": token.text + debug_info,
|
||||
"beg": format_time(token.start),
|
||||
"end": format_time(token.end),
|
||||
"diff": round(token.end - last_end_diarized, 2)
|
||||
})
|
||||
previous_speaker = speaker
|
||||
elif token.text: # Only append if text isn't empty
|
||||
lines[-1]["text"] += sep + token.text + debug_info
|
||||
lines[-1]["end"] = format_time(token.end)
|
||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||
|
||||
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
current_time = time() - self.beg_loop if self.beg_loop else None,
|
||||
diarization = self.args.diarization,
|
||||
debug = self.debug
|
||||
)
|
||||
# Handle undiarized text
|
||||
if undiarized_text:
|
||||
combined = sep.join(undiarized_text)
|
||||
@@ -510,7 +465,7 @@ class AudioProcessor:
|
||||
"buffer_transcription": buffer_transcription,
|
||||
"buffer_diarization": buffer_diarization,
|
||||
"remaining_time_transcription": state["remaining_time_transcription"],
|
||||
"remaining_time_diarization": state["remaining_time_diarization"]
|
||||
"remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
|
||||
}
|
||||
|
||||
current_response_signature = f"{response_status} | " + \
|
||||
|
||||
@@ -52,7 +52,7 @@ async def handle_websocket_results(websocket, results_generator):
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket results handler: {e}")
|
||||
logger.exception(f"Error in WebSocket results handler: {e}")
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
|
||||
@@ -57,7 +57,7 @@ class TranscriptionEngine:
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"model_path": './base.pt',
|
||||
"diarization_backend": "diart",
|
||||
"diarization_backend": "sortformer",
|
||||
# diart params:
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
@@ -121,13 +121,14 @@ class TranscriptionEngine:
|
||||
if self.args.diarization:
|
||||
if self.args.diarization_backend == "diart":
|
||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||
self.diarization = DiartDiarization(
|
||||
self.diarization_model = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
segmentation_model_name=self.args.segmentation_model,
|
||||
embedding_model_name=self.args.embedding_model
|
||||
)
|
||||
elif self.args.diarization_backend == "sortformer":
|
||||
raise ValueError('Sortformer backend in developement')
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
self.diarization_model = SortformerDiarization()
|
||||
else:
|
||||
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
|
||||
|
||||
@@ -152,4 +153,16 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
||||
confidence_validation = args.confidence_validation
|
||||
)
|
||||
return online
|
||||
|
||||
|
||||
|
||||
def online_diarization_factory(args, diarization_backend):
|
||||
if args.diarization_backend == "diart":
|
||||
online = diarization_backend
|
||||
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommanded
|
||||
|
||||
if args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||
return online
|
||||
|
||||
|
||||
@@ -1,145 +1,457 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from typing import List, Optional
|
||||
from queue import SimpleQueue, Empty
|
||||
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
except ImportError:
|
||||
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.spkcache = None # Speaker cache to store embeddings from start
|
||||
self.spkcache_lengths = None
|
||||
self.spkcache_preds = None # speaker cache predictions
|
||||
self.fifo = None # to save the embedding from the latest chunks
|
||||
self.fifo_lengths = None
|
||||
self.fifo_preds = None
|
||||
self.spk_perm = None
|
||||
self.mean_sil_emb = None
|
||||
self.n_sil_frames = None
|
||||
|
||||
|
||||
class SortformerDiarization:
|
||||
def __init__(self, model_name="nvidia/diar_streaming_sortformer_4spk-v2"):
|
||||
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||
self.diar_model.eval()
|
||||
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
|
||||
"""
|
||||
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
|
||||
"""
|
||||
self._load_model(model_name)
|
||||
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and configure the Sortformer model for streaming."""
|
||||
try:
|
||||
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||
self.diar_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.diar_model.to(torch.device("cuda"))
|
||||
if torch.cuda.is_available():
|
||||
self.diar_model.to(torch.device("cuda"))
|
||||
logger.info("Using CUDA for Sortformer model")
|
||||
else:
|
||||
logger.info("Using CPU for Sortformer model")
|
||||
|
||||
# Streaming parameters for speed
|
||||
self.diar_model.sortformer_modules.chunk_len = 12
|
||||
self.diar_model.sortformer_modules.chunk_right_context = 1
|
||||
self.diar_model.sortformer_modules.spkcache_len = 188
|
||||
self.diar_model.sortformer_modules.fifo_len = 188
|
||||
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
self.diar_model.sortformer_modules.log = False
|
||||
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
self.batch_size = 1
|
||||
self.processed_signal_offset = torch.zeros((self.batch_size,), dtype=torch.long, device=self.diar_model.device)
|
||||
self.diar_model.sortformer_modules.chunk_len = 10
|
||||
self.diar_model.sortformer_modules.subsampling_factor = 10
|
||||
self.diar_model.sortformer_modules.chunk_right_context = 0
|
||||
self.diar_model.sortformer_modules.chunk_left_context = 10
|
||||
self.diar_model.sortformer_modules.spkcache_len = 188
|
||||
self.diar_model.sortformer_modules.fifo_len = 188
|
||||
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
self.diar_model.sortformer_modules.log = False
|
||||
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Sortformer model: {e}")
|
||||
raise
|
||||
|
||||
class SortformerDiarizationOnline:
|
||||
def __init__(self, shared_model, sample_rate: int = 16000):
|
||||
"""
|
||||
Initialize the streaming Sortformer diarization system.
|
||||
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.sample_rate = 16000
|
||||
Args:
|
||||
sample_rate: Audio sample rate (default: 16000)
|
||||
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.speaker_segments = []
|
||||
|
||||
self.streaming_state = self.diar_model.sortformer_modules.init_streaming_state(
|
||||
batch_size=self.batch_size,
|
||||
async_streaming=True,
|
||||
device=self.diar_model.device
|
||||
self.buffer_audio = np.array([], dtype=np.float32)
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
self.processed_time = 0.0
|
||||
self.debug = False
|
||||
|
||||
self.diar_model = shared_model.diar_model
|
||||
|
||||
self.audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size=0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0
|
||||
)
|
||||
self.total_preds = torch.zeros((self.batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=self.diar_model.device)
|
||||
|
||||
|
||||
def _prepare_audio_signal(self, signal):
|
||||
audio_signal = torch.tensor(signal).unsqueeze(0).to(self.diar_model.device)
|
||||
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(self.diar_model.device)
|
||||
processed_signal, processed_signal_length = self.diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
|
||||
return processed_signal, processed_signal_length
|
||||
|
||||
def _create_streaming_loader(self, processed_signal, processed_signal_length):
|
||||
streaming_loader = self.diar_model.sortformer_modules.streaming_feat_loader(
|
||||
feat_seq=processed_signal,
|
||||
feat_seq_length=processed_signal_length,
|
||||
feat_seq_offset=self.processed_signal_offset,
|
||||
|
||||
self.chunk_duration_seconds = (
|
||||
self.diar_model.sortformer_modules.chunk_len *
|
||||
self.diar_model.sortformer_modules.subsampling_factor *
|
||||
self.diar_model.preprocessor._cfg.window_stride
|
||||
)
|
||||
return streaming_loader
|
||||
|
||||
self._init_streaming_state()
|
||||
|
||||
self._previous_chunk_features = None
|
||||
self._chunk_index = 0
|
||||
self._len_prediction = None
|
||||
|
||||
# Audio buffer to store PCM chunks for debugging
|
||||
self.audio_buffer = []
|
||||
|
||||
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
|
||||
self.audio_chunk_buffer = []
|
||||
self.accumulated_duration = 0.0
|
||||
|
||||
logger.info("SortformerDiarization initialized successfully")
|
||||
|
||||
|
||||
def _init_streaming_state(self):
|
||||
"""Initialize the streaming state for the model."""
|
||||
batch_size = 1
|
||||
device = self.diar_model.device
|
||||
|
||||
self.streaming_state = StreamingSortformerState()
|
||||
self.streaming_state.spkcache = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_preds = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.fifo = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
|
||||
# Initialize total predictions tensor
|
||||
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
"""
|
||||
Insert silence period by adjusting the global time offset.
|
||||
|
||||
Args:
|
||||
silence_duration: Duration of silence in seconds
|
||||
"""
|
||||
with self.segment_lock:
|
||||
self.global_time_offset += silence_duration
|
||||
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process an incoming audio chunk for diarization.
|
||||
Process audio data for diarization in streaming fashion.
|
||||
|
||||
Args:
|
||||
pcm_array: Audio data as numpy array
|
||||
"""
|
||||
self.audio_buffer = np.concatenate([self.audio_buffer, pcm_array])
|
||||
|
||||
# Process in fixed-size chunks (e.g., 1 second)
|
||||
chunk_size = self.sample_rate # 1 second of audio
|
||||
|
||||
while len(self.audio_buffer) >= chunk_size:
|
||||
chunk_to_process = self.audio_buffer[:chunk_size]
|
||||
self.audio_buffer = self.audio_buffer[chunk_size:]
|
||||
try:
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
|
||||
processed_signal, processed_signal_length = self._prepare_audio_signal(chunk_to_process)
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||
|
||||
current_offset_seconds = self.processed_signal_offset.item() * self.diar_model.preprocessor._cfg.window_stride
|
||||
|
||||
streaming_loader = self._create_streaming_loader(processed_signal, processed_signal_length)
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
if not len(self.buffer_audio) >= threshold:
|
||||
return
|
||||
|
||||
frame_duration_s = self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride
|
||||
chunk_duration_seconds = self.diar_model.sortformer_modules.chunk_len * frame_duration_s
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)
|
||||
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
|
||||
if self._previous_chunk_features is not None:
|
||||
to_add = self._previous_chunk_features[:, :, -99:]
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||
else:
|
||||
total_features = processed_signal_chunk
|
||||
|
||||
self._previous_chunk_features = processed_signal_chunk
|
||||
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2)
|
||||
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
right_offset = 8
|
||||
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
|
||||
# Convert predictions to speaker segments
|
||||
self._process_predictions()
|
||||
|
||||
self._chunk_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in diarize: {e}")
|
||||
raise
|
||||
|
||||
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
|
||||
|
||||
for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
|
||||
with torch.inference_mode():
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=feat_lengths,
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
def _process_predictions(self):
|
||||
"""Process model predictions and convert to speaker segments."""
|
||||
try:
|
||||
preds_np = self.total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
|
||||
if self._len_prediction is None:
|
||||
self._len_prediction = len(active_speakers)
|
||||
|
||||
# Get predictions for current chunk
|
||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||
|
||||
with self.segment_lock:
|
||||
# Process predictions into segments
|
||||
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||
|
||||
for idx, spk in enumerate(current_chunk_preds):
|
||||
start_time = base_time + idx * frame_duration
|
||||
end_time = base_time + (idx + 1) * frame_duration
|
||||
|
||||
num_new_frames = feat_lengths[0].item()
|
||||
|
||||
# Get predictions for the current chunk from the end of total_preds
|
||||
preds_np = self.total_preds[0, -num_new_frames:].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
|
||||
for idx, spk in enumerate(active_speakers):
|
||||
start_time = current_offset_seconds + (i * chunk_duration_seconds) + (idx * frame_duration_s)
|
||||
end_time = start_time + frame_duration_s
|
||||
# Check if this continues the last segment or starts a new one
|
||||
if (self.speaker_segments and
|
||||
self.speaker_segments[-1].speaker == spk and
|
||||
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
|
||||
# Continue existing segment
|
||||
self.speaker_segments[-1].end = end_time
|
||||
else:
|
||||
|
||||
if self.speaker_segments and self.speaker_segments[-1].speaker == spk + 1:
|
||||
self.speaker_segments[-1].end = end_time
|
||||
else:
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
speaker=int(spk + 1),
|
||||
start=start_time,
|
||||
end=end_time
|
||||
))
|
||||
|
||||
self.processed_signal_offset += processed_signal_length
|
||||
# Create new segment
|
||||
self.speaker_segments.append(SpeakerSegment(
|
||||
speaker=spk,
|
||||
start=start_time,
|
||||
end=end_time
|
||||
))
|
||||
|
||||
# Update processed time
|
||||
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
|
||||
|
||||
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing predictions: {e}")
|
||||
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, **kwargs) -> list:
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens with timing information
|
||||
use_punctuation_split: Whether to use punctuation for boundary refinement
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
for token in tokens:
|
||||
for segment in self.speaker_segments:
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
token.speaker = segment.speaker
|
||||
with self.segment_lock:
|
||||
segments = self.speaker_segments.copy()
|
||||
|
||||
if not segments or not tokens:
|
||||
logger.debug("No segments or tokens available for speaker assignment")
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
use_punctuation_split = False
|
||||
if not use_punctuation_split:
|
||||
# Simple overlap-based assignment
|
||||
for token in tokens:
|
||||
token.speaker = -1 # Default to no speaker
|
||||
for segment in segments:
|
||||
# Check for timing overlap
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
|
||||
break
|
||||
else:
|
||||
# Use punctuation-aware assignment (similar to diart_backend)
|
||||
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
|
||||
"""
|
||||
Assign speakers to tokens with punctuation-aware boundary adjustment.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
tokens: List of tokens to assign speakers to
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
|
||||
# Convert segments to concatenated format
|
||||
segments_concatenated = self._concatenate_speakers(segments)
|
||||
|
||||
# Adjust segment boundaries based on punctuation
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
|
||||
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
# Ensure non-overlapping tokens
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
# Assign speakers based on adjusted segments
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
|
||||
return tokens
|
||||
|
||||
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
|
||||
"""
|
||||
Concatenate consecutive segments from the same speaker.
|
||||
|
||||
Args:
|
||||
segments: List of speaker segments
|
||||
|
||||
Returns:
|
||||
List of concatenated speaker segments
|
||||
"""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
|
||||
|
||||
for segment in segments[1:]:
|
||||
speaker = segment.speaker + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
|
||||
return segments_concatenated
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.speaker_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.speaker_segments = [
|
||||
segment for segment in self.speaker_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Cleanup resources.
|
||||
"""
|
||||
logger.info("Closing SortformerDiarization.")
|
||||
"""Close the diarization system and clean up resources."""
|
||||
logger.info("Closing SortformerDiarization")
|
||||
with self.segment_lock:
|
||||
self.speaker_segments.clear()
|
||||
|
||||
if self.debug:
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # mono audio
|
||||
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(audio_data_int16.tobytes())
|
||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
"""Extract number from speaker string (compatibility function)."""
|
||||
import re
|
||||
m = re.search(r'\d+', s)
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
import librosa
|
||||
an4_audio = 'new_audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
|
||||
async def main():
|
||||
"""TEST ONLY."""
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
|
||||
diarization_pipeline = SortformerDiarization()
|
||||
|
||||
# Simulate streaming
|
||||
chunk_size = 16000 # 1 second
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
import asyncio
|
||||
asyncio.run(diarization_pipeline.diarize(chunk))
|
||||
|
||||
for segment in diarization_pipeline.speaker_segments:
|
||||
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||
print("\n" + "=" * 50)
|
||||
print("ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
diarization = SortformerDiarization(sample_rate=16000)
|
||||
chunk_size = 1600
|
||||
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
await diarization.diarize(chunk)
|
||||
print(f"Processed chunk {i // chunk_size + 1}")
|
||||
|
||||
segments = diarization.get_segments()
|
||||
print("\nDiarization results:")
|
||||
for segment in segments:
|
||||
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import math
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
except ImportError:
|
||||
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||
|
||||
|
||||
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
diar_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
diar_model.to(torch.device("cuda"))
|
||||
|
||||
# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
|
||||
# diar_model.sortformer_modules.chunk_len = 6
|
||||
# diar_model.sortformer_modules.spkcache_len = 188
|
||||
# diar_model.sortformer_modules.chunk_right_context = 7
|
||||
# diar_model.sortformer_modules.fifo_len = 188
|
||||
# diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
# diar_model.sortformer_modules.log = False
|
||||
|
||||
|
||||
# here we change the settings for our goal: speed!
|
||||
# we want batches of around 1 second. one frame is 0.08s, so 1s is 12.5 frames. we take 12.
|
||||
diar_model.sortformer_modules.chunk_len = 12
|
||||
|
||||
# for more speed, we reduce the 'right context'. it's like looking less into the future.
|
||||
diar_model.sortformer_modules.chunk_right_context = 1
|
||||
|
||||
# we keep the rest same for now
|
||||
diar_model.sortformer_modules.spkcache_len = 188
|
||||
diar_model.sortformer_modules.fifo_len = 188
|
||||
diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
diar_model.sortformer_modules.log = False
|
||||
diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
batch_size = 1
|
||||
processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)
|
||||
|
||||
# from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
|
||||
# from nemo.collections.asr.modules.audio_preprocessing import get_features
|
||||
from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessor
|
||||
|
||||
|
||||
def prepare_audio_signal(signal):
|
||||
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
|
||||
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
|
||||
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
|
||||
window_size= 0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128).get_features(audio_signal, audio_signal_length)
|
||||
return processed_signal, processed_signal_length
|
||||
|
||||
|
||||
def streaming_feat_loader(
|
||||
feat_seq, feat_seq_length, feat_seq_offset
|
||||
):
|
||||
"""
|
||||
Load a chunk of feature sequence for streaming inference.
|
||||
|
||||
Args:
|
||||
feat_seq (torch.Tensor): Tensor containing feature sequence
|
||||
Shape: (batch_size, feat_dim, feat frame count)
|
||||
feat_seq_length (torch.Tensor): Tensor containing feature sequence lengths
|
||||
Shape: (batch_size,)
|
||||
feat_seq_offset (torch.Tensor): Tensor containing feature sequence offsets
|
||||
Shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
chunk_idx (int): Index of the current chunk
|
||||
chunk_feat_seq (torch.Tensor): Tensor containing the chunk of feature sequence
|
||||
Shape: (batch_size, diar frame count, feat_dim)
|
||||
feat_lengths (torch.Tensor): Tensor containing lengths of the chunk of feature sequence
|
||||
Shape: (batch_size,)
|
||||
"""
|
||||
feat_len = feat_seq.shape[2]
|
||||
num_chunks = math.ceil(feat_len / (diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor))
|
||||
if False:
|
||||
logging.info(
|
||||
f"feat_len={feat_len}, num_chunks={num_chunks}, "
|
||||
f"feat_seq_length={feat_seq_length}, feat_seq_offset={feat_seq_offset}"
|
||||
)
|
||||
|
||||
stt_feat, end_feat, chunk_idx = 0, 0, 0
|
||||
while end_feat < feat_len:
|
||||
left_offset = min(diar_model.sortformer_modules.chunk_left_context * diar_model.sortformer_modules.subsampling_factor, stt_feat)
|
||||
end_feat = min(stt_feat + diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor, feat_len)
|
||||
right_offset = min(diar_model.sortformer_modules.chunk_right_context * diar_model.sortformer_modules.subsampling_factor, feat_len - end_feat)
|
||||
chunk_feat_seq = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
|
||||
feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
|
||||
0, chunk_feat_seq.shape[2]
|
||||
)
|
||||
feat_lengths = feat_lengths * (feat_seq_offset < end_feat)
|
||||
stt_feat = end_feat
|
||||
chunk_feat_seq_t = torch.transpose(chunk_feat_seq, 1, 2)
|
||||
if False:
|
||||
logging.info(
|
||||
f"chunk_idx: {chunk_idx}, "
|
||||
f"chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
|
||||
f"chunk_feat_lengths: {feat_lengths}"
|
||||
)
|
||||
yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
|
||||
chunk_idx += 1
|
||||
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
spkcache = None # Speaker cache to store embeddings from start
|
||||
spkcache_lengths = None #
|
||||
spkcache_preds = None # speaker cache predictions
|
||||
fifo = None # to save the embedding from the latest chunks
|
||||
fifo_lengths = None
|
||||
fifo_preds = None
|
||||
spk_perm = None
|
||||
mean_sil_emb = None
|
||||
n_sil_frames = None
|
||||
|
||||
|
||||
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
|
||||
"""
|
||||
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size for tensors in streaming state
|
||||
async_streaming (bool): True for asynchronous update, False for synchronous update
|
||||
device (torch.device): Device for tensors in streaming state
|
||||
|
||||
Returns:
|
||||
streaming_state (SortformerStreamingState): initialized streaming state
|
||||
"""
|
||||
streaming_state = StreamingSortformerState()
|
||||
if async_streaming:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
|
||||
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
|
||||
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
|
||||
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
else:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
|
||||
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
return streaming_state
|
||||
|
||||
def process_diarization(signal, chunks):
|
||||
|
||||
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
|
||||
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
|
||||
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
|
||||
window_size= 0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128).get_features(audio_signal, audio_signal_length)
|
||||
|
||||
|
||||
streaming_loader = streaming_feat_loader(processed_signal, processed_signal_length, processed_signal_offset)
|
||||
|
||||
|
||||
streaming_state = init_streaming_state(diar_model.sortformer_modules,
|
||||
batch_size = batch_size,
|
||||
async_streaming = True,
|
||||
device = diar_model.device
|
||||
)
|
||||
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
|
||||
|
||||
|
||||
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
|
||||
print(f"Chunk duration: {chunk_duration_seconds} seconds")
|
||||
|
||||
l_speakers = [
|
||||
{'start_time': 0,
|
||||
'end_time': 0,
|
||||
'speaker': 0
|
||||
}
|
||||
]
|
||||
len_prediction = None
|
||||
left_offset = 0
|
||||
right_offset = 8
|
||||
for i, chunk_feat_seq_t, _, _, _ in streaming_loader:
|
||||
with torch.inference_mode():
|
||||
streaming_state, total_preds = diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||
streaming_state=streaming_state,
|
||||
total_preds=total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
left_offset = 8
|
||||
preds_np = total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
if len_prediction is None:
|
||||
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
|
||||
frame_duration = chunk_duration_seconds / len_prediction
|
||||
active_speakers = active_speakers[-len_prediction:]
|
||||
print(chunk_feat_seq_t.shape, total_preds.shape)
|
||||
for idx, spk in enumerate(active_speakers):
|
||||
if spk != l_speakers[-1]['speaker']:
|
||||
l_speakers.append(
|
||||
{'start_time': i * chunk_duration_seconds + idx * frame_duration,
|
||||
'end_time': i * chunk_duration_seconds + (idx + 1) * frame_duration,
|
||||
'speaker': spk
|
||||
})
|
||||
else:
|
||||
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
|
||||
|
||||
print(l_speakers)
|
||||
"""
|
||||
Should print
|
||||
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
|
||||
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
|
||||
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
|
||||
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
import librosa
|
||||
an4_audio = 'new_audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio,sr=16000)
|
||||
|
||||
"""
|
||||
ground truth:
|
||||
speaker 0 : 0:00 - 0:09
|
||||
speaker 1 : 0:09 - 0:19
|
||||
speaker 2 : 0:19 - 0:25
|
||||
speaker 0 : 0:25 - end
|
||||
"""
|
||||
|
||||
# Simulate streaming
|
||||
chunk_size = 16000 # 1 second
|
||||
chunks = []
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
|
||||
process_diarization(signal, chunks)
|
||||
205
whisperlivekit/diarization/sortformer_backend_offline.py
Normal file
205
whisperlivekit/diarization/sortformer_backend_offline.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
import librosa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_model():
|
||||
|
||||
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
diar_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
diar_model.to(torch.device("cuda"))
|
||||
|
||||
#we target 1 second lag for the moment. chunk_len could be reduced.
|
||||
diar_model.sortformer_modules.chunk_len = 10
|
||||
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
|
||||
|
||||
diar_model.sortformer_modules.chunk_right_context = 0 #no.
|
||||
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
|
||||
|
||||
diar_model.sortformer_modules.spkcache_len = 188
|
||||
diar_model.sortformer_modules.fifo_len = 188
|
||||
diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
diar_model.sortformer_modules.log = False
|
||||
diar_model.sortformer_modules._check_streaming_parameters()
|
||||
|
||||
|
||||
audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size= 0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
|
||||
|
||||
return diar_model, audio2mel
|
||||
|
||||
diar_model, audio2mel = load_model()
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
spkcache = None # Speaker cache to store embeddings from start
|
||||
spkcache_lengths = None #
|
||||
spkcache_preds = None # speaker cache predictions
|
||||
fifo = None # to save the embedding from the latest chunks
|
||||
fifo_lengths = None
|
||||
fifo_preds = None
|
||||
spk_perm = None
|
||||
mean_sil_emb = None
|
||||
n_sil_frames = None
|
||||
|
||||
|
||||
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
|
||||
"""
|
||||
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size for tensors in streaming state
|
||||
async_streaming (bool): True for asynchronous update, False for synchronous update
|
||||
device (torch.device): Device for tensors in streaming state
|
||||
|
||||
Returns:
|
||||
streaming_state (SortformerStreamingState): initialized streaming state
|
||||
"""
|
||||
streaming_state = StreamingSortformerState()
|
||||
if async_streaming:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
|
||||
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
|
||||
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
|
||||
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
else:
|
||||
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
|
||||
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
return streaming_state
|
||||
|
||||
|
||||
def process_diarization(chunks):
|
||||
"""
|
||||
what it does:
|
||||
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
|
||||
2. STFT: Computes the Short-Time Fourier Transform using:
|
||||
- the window of window_size=0.025 --> size of a window : 400 samples
|
||||
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
|
||||
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
|
||||
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
|
||||
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
|
||||
6. Normalization: Skips normalization since `normalize="NA"`
|
||||
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
|
||||
"""
|
||||
previous_chunk = None
|
||||
l_chunk_feat_seq_t = []
|
||||
for chunk in chunks:
|
||||
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
|
||||
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
|
||||
if previous_chunk is not None:
|
||||
to_add = previous_chunk[:, :, -99:]
|
||||
total = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||
else:
|
||||
total = processed_signal_chunk
|
||||
previous_chunk = processed_signal_chunk
|
||||
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
|
||||
|
||||
batch_size = 1
|
||||
streaming_state = init_streaming_state(diar_model.sortformer_modules,
|
||||
batch_size = batch_size,
|
||||
async_streaming = True,
|
||||
device = diar_model.device
|
||||
)
|
||||
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
|
||||
|
||||
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
|
||||
|
||||
l_speakers = [
|
||||
{'start_time': 0,
|
||||
'end_time': 0,
|
||||
'speaker': 0
|
||||
}
|
||||
]
|
||||
len_prediction = None
|
||||
left_offset = 0
|
||||
right_offset = 8
|
||||
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
|
||||
with torch.inference_mode():
|
||||
streaming_state, total_preds = diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||
streaming_state=streaming_state,
|
||||
total_preds=total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
left_offset = 8
|
||||
preds_np = total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
if len_prediction is None:
|
||||
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
|
||||
frame_duration = chunk_duration_seconds / len_prediction
|
||||
active_speakers = active_speakers[-len_prediction:]
|
||||
for idx, spk in enumerate(active_speakers):
|
||||
if spk != l_speakers[-1]['speaker']:
|
||||
l_speakers.append(
|
||||
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
|
||||
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
|
||||
'speaker': spk
|
||||
})
|
||||
else:
|
||||
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
|
||||
|
||||
|
||||
"""
|
||||
Should print
|
||||
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
|
||||
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
|
||||
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
|
||||
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
|
||||
"""
|
||||
for speaker in l_speakers:
|
||||
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
an4_audio = 'audio_test.mp3'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
# signal = signal[:-(len(signal)%16000)]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Expected ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
chunk_size = 16000 # 1 second
|
||||
chunks = []
|
||||
for i in range(0, len(signal), chunk_size):
|
||||
chunk = signal[i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
|
||||
process_diarization(chunks)
|
||||
@@ -61,7 +61,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
type=str,
|
||||
default="diart",
|
||||
default="sortformer",
|
||||
choices=["sortformer", "diart"],
|
||||
help="The diarization backend to use.",
|
||||
)
|
||||
|
||||
138
whisperlivekit/results_formater.py
Normal file
138
whisperlivekit/results_formater.py
Normal file
@@ -0,0 +1,138 @@
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?'}
|
||||
CHECK_AROUND = 4
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
|
||||
def is_punctuation(token):
|
||||
if token.text.strip() in PUNCTUATION_MARKS:
|
||||
return True
|
||||
return False
|
||||
|
||||
def next_punctuation_change(i, tokens):
|
||||
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||
if is_punctuation(tokens[ind]):
|
||||
return ind
|
||||
return None
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
|
||||
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||
token = tokens[ind]
|
||||
if is_punctuation(token):
|
||||
break
|
||||
if token.speaker != speaker:
|
||||
return ind, token.speaker
|
||||
return None, speaker
|
||||
|
||||
|
||||
def new_line(
|
||||
token,
|
||||
speaker,
|
||||
last_end_diarized,
|
||||
debug_info = ""
|
||||
):
|
||||
return {
|
||||
"speaker": int(speaker),
|
||||
"text": token.text + debug_info,
|
||||
"beg": format_time(token.start),
|
||||
"end": format_time(token.end),
|
||||
"diff": round(token.end - last_end_diarized, 2)
|
||||
}
|
||||
|
||||
|
||||
def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
|
||||
if token.text:
|
||||
lines[-1]["text"] += sep + token.text + debug_info
|
||||
lines[-1]["end"] = format_time(token.end)
|
||||
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
||||
|
||||
|
||||
def format_output(state, silence, current_time, diarization, debug):
|
||||
tokens = state["tokens"]
|
||||
buffer_transcription = state["buffer_transcription"]
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
sep = state["sep"]
|
||||
|
||||
previous_speaker = -1
|
||||
lines = []
|
||||
last_end_diarized = 0
|
||||
undiarized_text = []
|
||||
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
|
||||
last_punctuation = None
|
||||
for i, token in enumerate(tokens):
|
||||
speaker = token.speaker
|
||||
|
||||
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
speaker = 1
|
||||
if diarization and not tokens[-1].speaker == -2:
|
||||
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
|
||||
undiarized_text.append(token.text)
|
||||
continue
|
||||
elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
|
||||
speaker = previous_speaker
|
||||
if speaker not in [-1, 0]:
|
||||
last_end_diarized = max(token.end, last_end_diarized)
|
||||
|
||||
debug_info = ""
|
||||
if debug:
|
||||
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
|
||||
|
||||
if not lines:
|
||||
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
continue
|
||||
else:
|
||||
previous_speaker = lines[-1]['speaker']
|
||||
|
||||
if is_punctuation(token):
|
||||
last_punctuation = i
|
||||
|
||||
|
||||
if last_punctuation == i-1:
|
||||
if speaker != previous_speaker:
|
||||
# perfect, diarization perfectly aligned
|
||||
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
last_punctuation, next_punctuation = None, None
|
||||
continue
|
||||
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
|
||||
lines.append(new_line(token, new_speaker, last_end_diarized, debug_info = ""))
|
||||
else:
|
||||
# No speaker change to come
|
||||
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
|
||||
continue
|
||||
|
||||
|
||||
if speaker != previous_speaker:
|
||||
if speaker == -2 or previous_speaker == -2: #silences can happen anytime
|
||||
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
continue
|
||||
elif next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
|
||||
continue
|
||||
else: #we create a new speaker, but that's no ideal. We are not sure about the split. We prefer to append to previous line
|
||||
# lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
|
||||
pass
|
||||
|
||||
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
|
||||
return lines, undiarized_text, buffer_transcription, ''
|
||||
|
||||
@@ -400,7 +400,12 @@ async function startRecording() {
|
||||
isRecording = true;
|
||||
updateUI();
|
||||
} catch (err) {
|
||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||
if (window.location.hostname === "0.0.0.0") {
|
||||
statusText.textContent =
|
||||
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
|
||||
} else {
|
||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||
}
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user