Improve diarization backends

This commit is contained in:
Quentin Fuxa
2026-02-15 14:55:00 +01:00
parent 32de7b1276
commit e30f9a2573
2 changed files with 78 additions and 81 deletions

View File

@@ -20,25 +20,25 @@ logger = logging.getLogger(__name__)
class DiarizationObserver(Observer):
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
def __init__(self):
self.diarization_segments = []
self.processed_time = 0
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
def on_next(self, value: Tuple[Annotation, Any]):
annotation, audio = value
logger.debug("\n--- New Diarization Result ---")
duration = audio.extent.end - audio.extent.start
logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
logger.debug(f"Audio shape: {audio.data.shape}")
with self.segment_lock:
if audio.extent.end > self.processed_time:
self.processed_time = audio.extent.end
self.processed_time = audio.extent.end
if annotation and len(annotation._labels) > 0:
logger.debug("\nSpeaker segments:")
for speaker, label in annotation._labels.items():
@@ -51,25 +51,25 @@ class DiarizationObserver(Observer):
))
else:
logger.debug("\nNo speakers detected in this segment")
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
return self.diarization_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.diarization_segments = [
segment for segment in self.diarization_segments
segment for segment in self.diarization_segments
if current_time - segment.end < older_than
]
def on_error(self, error):
"""Handle an error in the stream."""
logger.debug(f"Error in diarization stream: {error}")
def on_completed(self):
"""Handle the completion of the stream."""
logger.debug("Diarization stream completed")
@@ -96,7 +96,7 @@ class WebSocketAudioSource(AudioSource):
self._processing_thread = threading.Thread(target=self._process_chunks)
self._processing_thread.daemon = True
self._processing_thread.start()
self._close_event.wait()
if self._processing_thread:
self._processing_thread.join(timeout=2.0)
@@ -106,30 +106,30 @@ class WebSocketAudioSource(AudioSource):
while not self._closed:
try:
audio_chunk = self._queue.get(timeout=0.1)
with self._buffer_lock:
self._buffer = np.concatenate([self._buffer, audio_chunk])
while len(self._buffer) >= self.block_size:
chunk = self._buffer[:self.block_size]
self._buffer = self._buffer[self.block_size:]
current_time = time.time()
time_since_last = current_time - self._last_chunk_time
if time_since_last < self.block_duration:
time.sleep(self.block_duration - time_since_last)
chunk_reshaped = chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
except Empty:
with self._buffer_lock:
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
self._buffer = np.array([], dtype=np.float32)
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self._last_chunk_time = time.time()
@@ -137,14 +137,14 @@ class WebSocketAudioSource(AudioSource):
logger.error(f"Error in audio processing thread: {e}")
self.stream.on_error(e)
break
with self._buffer_lock:
if len(self._buffer) > 0:
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
padded_chunk[:len(self._buffer)] = self._buffer
chunk_reshaped = padded_chunk.reshape(1, -1)
self.stream.on_next(chunk_reshaped)
self.stream.on_completed()
def close(self):
@@ -165,27 +165,27 @@ class DiartDiarization:
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
if config is None:
config = SpeakerDiarizationConfig(
segmentation=segmentation_model,
embedding=embedding_model,
)
self.pipeline = SpeakerDiarization(config=config)
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration)
self.custom_source = None
else:
self.custom_source = WebSocketAudioSource(
uri="websocket_source",
uri="websocket_source",
sample_rate=sample_rate,
block_duration=block_duration
)
self.source = self.custom_source
self.inference = StreamingInference(
pipeline=self.pipeline,
source=self.source,
@@ -205,14 +205,14 @@ class DiartDiarization:
async def diarize(self):
"""Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments()
return self.observer.get_segments()
def close(self):
"""Close the audio source."""
if self.custom_source:
self.custom_source.close()
def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
for segment in segments:
@@ -223,7 +223,7 @@ def concatenate_speakers(segments):
segments_concatenated[-1]['end'] = segment.end
# print("Segments concatenated:")
# for entry in segments_concatenated:
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
return segments_concatenated
@@ -281,4 +281,4 @@ def visualize_tokens(tokens):
conversation[-1]['text'] += token.text
print("Conversation:")
for entry in conversation:
print(f"Speaker {entry['speaker']}: {entry['text']}")
print(f"Speaker {entry['speaker']}: {entry['text']}")

View File

@@ -1,8 +1,6 @@
import logging
import threading
import time
import wave
from queue import Empty, SimpleQueue
from typing import List, Optional
import numpy as np
@@ -54,7 +52,7 @@ class SortformerDiarization:
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
"""
self._load_model(model_name)
def _load_model(self, model_name: str):
"""Load and configure the Sortformer model for streaming."""
try:
@@ -63,12 +61,12 @@ class SortformerDiarization:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.diar_model.to(device)
## to test
# for name, param in self.diar_model.named_parameters():
# if param.device != device:
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
logger.info(f"Using {device.type.upper()} for Sortformer model")
self.diar_model.sortformer_modules.chunk_len = 10
@@ -80,16 +78,16 @@ class SortformerDiarization:
self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters()
except Exception as e:
logger.error(f"Failed to load Sortformer model: {e}")
raise
class SortformerDiarizationOnline:
def __init__(self, shared_model, sample_rate: int = 16000):
"""
Initialize the streaming Sortformer diarization system.
Args:
sample_rate: Audio sample rate (default: 16000)
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
@@ -101,9 +99,9 @@ class SortformerDiarizationOnline:
self.segment_lock = threading.Lock()
self.global_time_offset = 0.0
self.debug = False
self.diar_model = shared_model.diar_model
self.audio2mel = AudioToMelSpectrogramPreprocessor(
window_size=0.025,
normalize="NA",
@@ -112,26 +110,26 @@ class SortformerDiarizationOnline:
pad_to=0
)
self.audio2mel.to(self.diar_model.device)
self.chunk_duration_seconds = (
self.diar_model.sortformer_modules.chunk_len *
self.diar_model.sortformer_modules.subsampling_factor *
self.diar_model.sortformer_modules.chunk_len *
self.diar_model.sortformer_modules.subsampling_factor *
self.diar_model.preprocessor._cfg.window_stride
)
self._init_streaming_state()
self._previous_chunk_features = None
self._chunk_index = 0
self._len_prediction = None
# Audio buffer to store PCM chunks for debugging
self.audio_buffer = []
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
self.audio_chunk_buffer = []
self.accumulated_duration = 0.0
logger.info("SortformerDiarization initialized successfully")
@@ -139,30 +137,30 @@ class SortformerDiarizationOnline:
"""Initialize the streaming state for the model."""
batch_size = 1
device = self.diar_model.device
self.streaming_state = StreamingSortformerState()
self.streaming_state.spkcache = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.spkcache_preds = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
device=device
)
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.fifo = torch.zeros(
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
device=device
)
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
def insert_silence(self, silence_duration: Optional[float]):
"""
Insert silence period by adjusting the global time offset.
Args:
silence_duration: Duration of silence in seconds
"""
@@ -174,48 +172,48 @@ class SortformerDiarizationOnline:
if self.debug:
self.audio_buffer.append(pcm_array.copy())
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
async def diarize(self):
"""
Process audio data for diarization in streaming fashion.
Args:
pcm_array: Audio data as numpy array
"""
threshold = int(self.chunk_duration_seconds * self.sample_rate)
if not len(self.buffer_audio) >= threshold:
return []
audio = self.buffer_audio[:threshold]
self.buffer_audio = self.buffer_audio[threshold:]
device = self.diar_model.device
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
audio_signal_chunk, audio_signal_length_chunk
)
processed_signal_chunk = processed_signal_chunk.to(device)
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
if self._previous_chunk_features is not None:
to_add = self._previous_chunk_features[:, :, -99:].to(device)
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
else:
total_features = processed_signal_chunk.to(device)
self._previous_chunk_features = processed_signal_chunk.to(device)
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
with torch.inference_mode():
left_offset = 8 if self._chunk_index > 0 else 0
right_offset = 8
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
@@ -223,9 +221,9 @@ class SortformerDiarizationOnline:
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
)
new_segments = self._process_predictions()
self._chunk_index += 1
return new_segments
@@ -233,13 +231,13 @@ class SortformerDiarizationOnline:
"""Process model predictions and convert to speaker segments."""
preds_np = self.total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if self._len_prediction is None:
self._len_prediction = len(active_speakers) #12
frame_duration = self.chunk_duration_seconds / self._len_prediction
current_chunk_preds = active_speakers[-self._len_prediction:]
new_segments = []
with self.segment_lock:
@@ -264,7 +262,7 @@ class SortformerDiarizationOnline:
)
)
return new_segments
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
@@ -275,10 +273,10 @@ class SortformerDiarizationOnline:
logger.info("Closing SortformerDiarization")
with self.segment_lock:
self.diarization_segments.clear()
if self.debug:
concatenated_audio = np.concatenate(self.audio_buffer)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
with wave.open("diarization_audio.wav", "wb") as wav_file:
wav_file.setnchannels(1) # mono audio
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
@@ -287,14 +285,13 @@ class SortformerDiarizationOnline:
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
from whisperlivekit.diarization.utils import extract_number
if __name__ == '__main__':
import asyncio
import librosa
async def main():
"""TEST ONLY."""
an4_audio = 'diarization_audio.wav'
@@ -304,24 +301,24 @@ if __name__ == '__main__':
print("\n" + "=" * 50)
print("ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
diarization_backend = SortformerDiarization()
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
chunk_size = 1600
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
new_segments = await diarization.diarize(chunk)
print(f"Processed chunk {i // chunk_size + 1}")
print(new_segments)
segments = diarization.get_segments()
print("\nDiarization results:")
for segment in segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
asyncio.run(main())