sortformer diar implementation v0.3

This commit is contained in:
Quentin Fuxa
2025-08-24 18:32:01 +02:00
parent 3393a08f7e
commit 58297daf6d
5 changed files with 492 additions and 8 deletions

View File

@@ -66,7 +66,8 @@ pip install whisperlivekit
| Optional | `pip install` |
|-----------|-------------|
| Speaker diarization | `whisperlivekit[diarization]` |
| Speaker diarization with Sortformer | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisperlivekit[whisper]` |
| Improved timestamps backend | `whisperlivekit[whisper-timestamped]` |
| Apple Silicon optimization backend | `whisperlivekit[mlx-whisper]` |
@@ -185,9 +186,10 @@ The package includes an HTML/JavaScript implementation [here](https://github.com
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |
| `--diarization-backend` | `diart` or `sortformer` | `diart` |
| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
### 🚀 Deployment Guide

View File

@@ -303,7 +303,7 @@ class AudioProcessor:
if type(item) is Silence:
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
if self.tokens:
asr_processing_logs += " | last_end = {self.tokens[-1].end} |"
asr_processing_logs += f" | last_end = {self.tokens[-1].end} |"
logger.info(asr_processing_logs)
if type(item) is Silence:
@@ -469,7 +469,7 @@ class AudioProcessor:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if speaker != previous_speaker or not lines:
lines.append({
"speaker": speaker,
"speaker": str(speaker),
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),

View File

@@ -52,7 +52,7 @@ async def handle_websocket_results(websocket, results_generator):
except WebSocketDisconnect:
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
except Exception as e:
logger.error(f"Error in WebSocket results handler: {e}")
logger.exception(f"Error in WebSocket results handler: {e}")
@app.websocket("/asr")

View File

@@ -127,7 +127,8 @@ class TranscriptionEngine:
embedding_model_name=self.args.embedding_model
)
elif self.args.diarization_backend == "sortformer":
raise ValueError('Sortformer backend in developement')
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
self.diarization = SortformerDiarization()
else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")

View File

@@ -1,11 +1,492 @@
import numpy as np
import torch
import logging
import threading
import time
import wave
from typing import List, Optional
from queue import SimpleQueue, Empty
from whisperlivekit.timed_objects import SpeakerSegment
logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
class StreamingSortformerState:
    """
    Plain container for the tensors carried between streaming Sortformer steps.

    Every field starts out as ``None``; the owner of the state is expected to
    fill them in before the first streaming step.

    Attributes:
        spkcache (torch.Tensor): Speaker cache to store embeddings from start
        spkcache_lengths (torch.Tensor): Lengths of the speaker cache
        spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
        fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
        fifo_lengths (torch.Tensor): Lengths of the FIFO queue
        fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
        spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
        mean_sil_emb (torch.Tensor): Mean silence embedding
        n_sil_frames (torch.Tensor): Number of silence frames
    """

    def __init__(self):
        # Speaker cache: embeddings, their lengths and their predictions.
        self.spkcache = self.spkcache_lengths = self.spkcache_preds = None
        # FIFO queue holding the most recent chunks, same three-part layout.
        self.fifo = self.fifo_lengths = self.fifo_preds = None
        # Speaker permutation information for the speaker cache.
        self.spk_perm = None
        # Silence statistics: mean embedding and frame count.
        self.mean_sil_emb = self.n_sil_frames = None
class SortformerDiarization:
    """
    Streaming speaker diarization built on a NeMo ``SortformerEncLabelModel``.

    Incoming PCM is buffered until at least ``chunk_duration_seconds`` of audio
    has accumulated, converted to mel features, run through
    ``forward_streaming_step`` and the resulting per-frame predictions are
    folded into ``SpeakerSegment`` objects stored in ``speaker_segments``.
    Access to ``speaker_segments`` and ``global_time_offset`` is guarded by
    ``segment_lock``.

    Speaker ids stored in segments are 0-based model indices; the token
    assignment helpers expose them 1-based.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2",
        debug_audio_path: Optional[str] = None,
    ):
        """
        Initialize the streaming Sortformer diarization system.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
            debug_audio_path: When set, every PCM chunk passed to ``diarize`` is
                kept in memory and written to this WAV path by ``close``.
                Disabled by default so a long-running stream does not retain
                all of its audio.
        """
        self.sample_rate = sample_rate
        self.debug_audio_path = debug_audio_path
        self.speaker_segments = []
        self.segment_lock = threading.Lock()
        self.global_time_offset = 0.0
        self.processed_time = 0.0

        # Load and configure the model
        self._load_model(model_name)

        # Initialize streaming state
        self._init_streaming_state()

        # Audio processing variables
        self._previous_chunk_features = None
        self._chunk_index = 0
        self._len_prediction = None

        # PCM chunks kept for debugging; only filled when debug_audio_path is set
        self.audio_buffer = []

        # Buffer for accumulating audio chunks until reaching chunk_duration_seconds
        self.audio_chunk_buffer = []
        self.accumulated_duration = 0.0

        logger.info("SortformerDiarization initialized successfully")

    def _load_model(self, model_name: str):
        """
        Load and configure the Sortformer model for streaming.

        Sets ``diar_model``, ``audio2mel`` and ``chunk_duration_seconds``.

        Args:
            model_name: Pre-trained model name passed to ``from_pretrained``

        Raises:
            Exception: Whatever the model loading/configuration raised, after
                it has been logged.
        """
        try:
            self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
            self.diar_model.eval()

            if torch.cuda.is_available():
                self.diar_model.to(torch.device("cuda"))
                logger.info("Using CUDA for Sortformer model")
            else:
                logger.info("Using CPU for Sortformer model")

            # Streaming configuration of the Sortformer modules
            modules = self.diar_model.sortformer_modules
            modules.chunk_len = 10
            modules.subsampling_factor = 10
            modules.chunk_right_context = 0
            modules.chunk_left_context = 10
            modules.spkcache_len = 188
            modules.fifo_len = 188
            modules.spkcache_update_period = 144
            modules.log = False
            modules._check_streaming_parameters()

            self.audio2mel = AudioToMelSpectrogramPreprocessor(
                window_size=0.025,
                normalize="NA",
                n_fft=512,
                features=128,
                pad_to=0
            )
            # diarize() feeds this preprocessor tensors that live on the model
            # device, so it has to live there too.
            self.audio2mel.to(self.diar_model.device)

            # Seconds of audio per chunk = output frames per chunk
            # * input frames per output frame * seconds per input frame.
            self.chunk_duration_seconds = (
                modules.chunk_len *
                modules.subsampling_factor *
                self.diar_model.preprocessor._cfg.window_stride
            )

            logger.info(f"Chunk duration: {self.chunk_duration_seconds:.2f}s")

        except Exception as e:
            logger.exception(f"Failed to load Sortformer model: {e}")
            raise

    def _init_streaming_state(self):
        """
        Initialize the streaming state for the model.

        Allocates zero-filled tensors (batch size 1) on the model device for the
        speaker cache, the FIFO queue and the silence statistics, and resets
        ``total_preds`` to an empty ``(1, 0, n_spk)`` tensor.
        """
        batch_size = 1
        device = self.diar_model.device
        modules = self.diar_model.sortformer_modules

        state = StreamingSortformerState()
        state.spkcache = torch.zeros(
            (batch_size, modules.spkcache_len, modules.fc_d_model),
            device=device
        )
        state.spkcache_preds = torch.zeros(
            (batch_size, modules.spkcache_len, modules.n_spk),
            device=device
        )
        state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
        state.fifo = torch.zeros(
            (batch_size, modules.fifo_len, modules.fc_d_model),
            device=device
        )
        state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
        state.mean_sil_emb = torch.zeros((batch_size, modules.fc_d_model), device=device)
        state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
        self.streaming_state = state

        # Initialize total predictions tensor
        self.total_preds = torch.zeros((batch_size, 0, modules.n_spk), device=device)

    def insert_silence(self, silence_duration: float):
        """
        Insert silence period by adjusting the global time offset.

        Args:
            silence_duration: Duration of silence in seconds
        """
        with self.segment_lock:
            self.global_time_offset += silence_duration
            logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")

    async def diarize(self, pcm_array: np.ndarray):
        """
        Process audio data for diarization in streaming fashion.

        The chunk is buffered; nothing is run until the buffered audio reaches
        ``chunk_duration_seconds``. Once it does, the whole buffer is consumed
        in a single streaming step and ``speaker_segments`` is updated.

        Args:
            pcm_array: Audio data as numpy array

        Raises:
            Exception: Any error raised during feature extraction or inference,
                after it has been logged.
        """
        # TODO: Handle case when stream ends with partial buffer
        # (accumulated_duration > 0 but < chunk_duration_seconds)
        try:
            # Keep the raw PCM only when debug recording was requested
            if self.debug_audio_path is not None:
                self.audio_buffer.append(pcm_array.copy())

            # Add to buffer and accumulate duration
            self.audio_chunk_buffer.append(pcm_array.copy())
            self.accumulated_duration += len(pcm_array) / self.sample_rate

            # Check if we have accumulated enough audio
            if self.accumulated_duration < self.chunk_duration_seconds:
                logger.debug(
                    "Accumulating audio: %.2f/%.2fs",
                    self.accumulated_duration,
                    self.chunk_duration_seconds,
                )
                return

            # Concatenate all buffered audio chunks
            concatenated_audio = np.concatenate(self.audio_chunk_buffer)

            # Reset buffer and accumulated duration
            self.audio_chunk_buffer = []
            self.accumulated_duration = 0.0

            device = self.diar_model.device

            # Convert audio to torch tensor
            audio_signal_chunk = torch.tensor(concatenated_audio).unsqueeze(0).to(device)
            audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(device)

            # Extract mel features
            processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
                audio_signal_chunk, audio_signal_length_chunk
            )

            # Handle feature overlap for continuity
            if self._previous_chunk_features is not None:
                # Add overlap from previous chunk (99 frames as in offline version)
                to_add = self._previous_chunk_features[:, :, -99:]
                total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
            else:
                total_features = processed_signal_chunk

            # Store current features for next iteration
            self._previous_chunk_features = processed_signal_chunk

            # Transpose for model input
            chunk_feat_seq_t = torch.transpose(total_features, 1, 2)

            # Process with streaming model
            with torch.inference_mode():
                left_offset = 8 if self._chunk_index > 0 else 0
                right_offset = 8

                self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
                    processed_signal=chunk_feat_seq_t,
                    # Keep the length on the same device as the features
                    processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]], device=device),
                    streaming_state=self.streaming_state,
                    total_preds=self.total_preds,
                    left_offset=left_offset,
                    right_offset=right_offset,
                )

            # Convert predictions to speaker segments
            self._process_predictions()

            self._chunk_index += 1

        except Exception as e:
            logger.exception(f"Error in diarize: {e}")
            raise

    def _process_predictions(self):
        """
        Process model predictions and convert to speaker segments.

        Takes the arg-max speaker of each of the last ``_len_prediction`` frames
        of ``total_preds`` and either extends the last segment (same speaker and
        contiguous in time) or appends a new one. Errors are logged and not
        propagated.
        """
        # NOTE(review): frame timing assumes every processed chunk lasts exactly
        # chunk_duration_seconds; diarize() may consume more audio than that when
        # the incoming chunks do not divide it evenly - confirm against callers.
        try:
            preds_np = self.total_preds[0].cpu().numpy()
            active_speakers = np.argmax(preds_np, axis=1)

            if len(active_speakers) == 0:
                # Nothing predicted yet; do not latch a zero frame count, which
                # would make every later call divide by zero.
                logger.debug("No predictions available yet")
                return

            if self._len_prediction is None:
                self._len_prediction = len(active_speakers)

            # Get predictions for current chunk
            frame_duration = self.chunk_duration_seconds / self._len_prediction
            current_chunk_preds = active_speakers[-self._len_prediction:]

            with self.segment_lock:
                # Process predictions into segments
                base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset

                for idx, spk in enumerate(current_chunk_preds):
                    # Plain int so segments do not carry numpy scalar types
                    spk = int(spk)
                    start_time = base_time + idx * frame_duration
                    end_time = base_time + (idx + 1) * frame_duration

                    # Check if this continues the last segment or starts a new one
                    if (self.speaker_segments and
                            self.speaker_segments[-1].speaker == spk and
                            abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
                        # Continue existing segment
                        self.speaker_segments[-1].end = end_time
                    else:
                        # Create new segment
                        self.speaker_segments.append(SpeakerSegment(
                            speaker=spk,
                            start=start_time,
                            end=end_time
                        ))
                        logger.debug("New speaker segment: %s", self.speaker_segments[-1])

                # Update processed time
                self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)

                logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")

        except Exception as e:
            # Best effort: a failed conversion must not stop the stream.
            logger.exception(f"Error processing predictions: {e}")

    def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
        """
        Assign speakers to tokens based on timing overlap with speaker segments.

        Args:
            tokens: List of tokens with timing information
            use_punctuation_split: Whether to use punctuation for boundary refinement

        Returns:
            List of tokens with speaker assignments. Without punctuation
            splitting, a token gets the 1-based id of the first overlapping
            segment, or -1 when no segment overlaps it.
        """
        with self.segment_lock:
            segments = self.speaker_segments.copy()

        if not segments or not tokens:
            logger.debug("No segments or tokens available for speaker assignment")
            return tokens

        logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")

        if not use_punctuation_split:
            # Simple overlap-based assignment
            for token in tokens:
                token.speaker = -1  # Default to no speaker
                for segment in segments:
                    # Check for timing overlap
                    if not (segment.end <= token.start or segment.start >= token.end):
                        token.speaker = segment.speaker + 1  # Convert to 1-based indexing
                        break
        else:
            # Use punctuation-aware assignment (similar to diart_backend)
            tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)

        return tokens

    def _add_speaker_to_tokens_with_punctuation(self, segments: "List[SpeakerSegment]", tokens: list) -> list:
        """
        Assign speakers to tokens with punctuation-aware boundary adjustment.

        Segment ends are snapped to the nearest sentence-ending punctuation
        token ('.', '!', '?'), token times are made strictly increasing (this
        mutates ``token.start``/``token.end``), and each token is then given the
        speaker of the first segment that fully contains its end.

        Args:
            segments: List of speaker segments
            tokens: List of tokens to assign speakers to

        Returns:
            List of tokens with speaker assignments
        """
        punctuation_marks = {'.', '!', '?'}
        punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]

        # Convert segments to concatenated format
        segments_concatenated = self._concatenate_speakers(segments)

        # Adjust segment boundaries based on punctuation
        for ind, segment in enumerate(segments_concatenated):
            for i, punctuation_token in enumerate(punctuation_tokens):
                if punctuation_token.start > segment['end']:
                    # First punctuation after the segment end: snap the end to
                    # whichever of the previous/next punctuation is closer.
                    after_length = punctuation_token.start - segment['end']
                    before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')

                    if before_length > after_length:
                        segment['end'] = punctuation_token.start
                        if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                            segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                    else:
                        segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
                        if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                            segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                    break

        # Ensure non-overlapping tokens
        last_end = 0.0
        for token in tokens:
            start = max(last_end + 0.01, token.start)
            token.start = start
            token.end = max(start, token.end)
            last_end = token.end

        # Assign speakers based on adjusted segments
        ind_last_speaker = 0
        for segment in segments_concatenated:
            # Enumerate with the absolute offset so ind_last_speaker stays an
            # index into `tokens`; a slice-relative index would make later
            # segments rescan and overwrite tokens that are already assigned.
            for i, token in enumerate(tokens[ind_last_speaker:], start=ind_last_speaker):
                if token.end <= segment['end']:
                    token.speaker = segment['speaker']
                    ind_last_speaker = i + 1
                elif token.start > segment['end']:
                    break

        return tokens

    def _concatenate_speakers(self, segments: "List[SpeakerSegment]") -> List[dict]:
        """
        Concatenate consecutive segments from the same speaker.

        Args:
            segments: List of speaker segments

        Returns:
            List of dicts with keys ``speaker`` (1-based), ``begin`` and ``end``
        """
        if not segments:
            return []

        segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]

        for segment in segments[1:]:
            speaker = segment.speaker + 1
            if segments_concatenated[-1]['speaker'] != speaker:
                segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
            else:
                segments_concatenated[-1]['end'] = segment.end

        return segments_concatenated

    def get_segments(self) -> "List[SpeakerSegment]":
        """Get a copy of the current speaker segments."""
        with self.segment_lock:
            return self.speaker_segments.copy()

    def clear_old_segments(self, older_than: float = 30.0):
        """
        Clear segments older than the specified time.

        Args:
            older_than: Segments whose end lies at least this many seconds
                before ``processed_time`` are dropped
        """
        with self.segment_lock:
            current_time = self.processed_time
            self.speaker_segments = [
                segment for segment in self.speaker_segments
                if current_time - segment.end < older_than
            ]
            logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")

    def close(self):
        """
        Close the diarization system and clean up resources.

        Clears the stored speaker segments. When ``debug_audio_path`` was given,
        the recorded audio is written there as a mono 16-bit WAV file; a failure
        to write is logged and not propagated.
        """
        logger.info("Closing SortformerDiarization")
        with self.segment_lock:
            self.speaker_segments.clear()

        if self.debug_audio_path is None:
            return

        if not self.audio_buffer:
            logger.info("No audio data to save")
            return

        try:
            # Concatenate all PCM chunks
            concatenated_audio = np.concatenate(self.audio_buffer)

            # Assumes float PCM in [-1, 1] - TODO confirm against caller.
            # Clip first so out-of-range samples saturate instead of wrapping.
            audio_data_int16 = (np.clip(concatenated_audio, -1.0, 1.0) * 32767).astype(np.int16)

            # Write to WAV file
            with wave.open(self.debug_audio_path, "wb") as wav_file:
                wav_file.setnchannels(1)  # Mono audio
                wav_file.setsampwidth(2)  # 2 bytes per sample (int16)
                wav_file.setframerate(self.sample_rate)
                wav_file.writeframes(audio_data_int16.tobytes())

            logger.info(f"Saved {len(concatenated_audio)} samples to {self.debug_audio_path}")
        except Exception as e:
            logger.exception(f"Error saving audio to WAV file: {e}")
        finally:
            # Release the recorded audio whether or not the write succeeded
            self.audio_buffer = []
def extract_number(s: str) -> int:
    """
    Return the first run of decimal digits in *s* as an int (compatibility function).

    Args:
        s: Speaker string to scan

    Returns:
        The integer value of the first digit run, or 0 when *s* has no digits
    """
    import re

    first_digits = re.search(r'\d+', s)
    if first_digits is None:
        return 0
    return int(first_digits.group())
if __name__ == '__main__':
    import asyncio
    import librosa

    async def main():
        """Run SortformerDiarization over the first 30 s of a test file, one second at a time."""
        target_rate = 16000

        # Load and trim the audio; any loading problem is reported and ends the run.
        try:
            audio_path = 'audio_test.mp3'
            signal, sr = librosa.load(audio_path, sr=16000)
            signal = signal[:target_rate * 30]  # 30 seconds
        except Exception as e:
            print(f"Error loading audio file: {e}")
            return

        # Reference annotation printed for manual comparison with the results below.
        banner = "=" * 50
        print("\n" + banner)
        for line in (
            "Expected ground truth:",
            "Speaker 0: 0:00 - 0:09",
            "Speaker 1: 0:09 - 0:19",
            "Speaker 2: 0:19 - 0:25",
            "Speaker 0: 0:25 - 0:30",
        ):
            print(line)
        print(banner)

        # Create diarization instance
        diarization = SortformerDiarization(sample_rate=16000)

        # Feed the signal in one-second pieces
        chunk_size = 16000
        total_samples = len(signal)
        print(f"Processing {total_samples} samples in {total_samples // chunk_size} chunks...")

        for chunk_number, offset in enumerate(range(0, total_samples, chunk_size), start=1):
            await diarization.diarize(signal[offset:offset + chunk_size])
            print(f"Processed chunk {chunk_number}")

        # diarization.close() is not called here (left disabled in this script).

        # Print results
        print("\nDiarization results:")
        for segment in diarization.get_segments():
            print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")

    asyncio.run(main())