import numpy as np
import torch
import logging
import threading
import time
import wave
from typing import List, Optional
from queue import SimpleQueue, Empty

from whisperlivekit.timed_objects import SpeakerSegment

logger = logging.getLogger(__name__)

try:
    from nemo.collections.asr.models import SortformerEncLabelModel
    from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
except ImportError:
    raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")


class StreamingSortformerState:
    """
    Holds the mutable state of the streaming Sortformer model between chunks.

    Attributes:
        spkcache (torch.Tensor): Speaker cache storing embeddings from the start of the stream
        spkcache_lengths (torch.Tensor): Lengths of the speaker cache
        spkcache_preds (torch.Tensor): Speaker predictions for the speaker-cache frames
        fifo (torch.Tensor): FIFO queue holding embeddings from the latest chunks
        fifo_lengths (torch.Tensor): Lengths of the FIFO queue
        fifo_preds (torch.Tensor): Speaker predictions for the FIFO-queue frames
        spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
        mean_sil_emb (torch.Tensor): Mean silence embedding
        n_sil_frames (torch.Tensor): Number of silence frames
    """

    def __init__(self):
        self.spkcache = None  # Speaker cache to store embeddings from start
        self.spkcache_lengths = None
        self.spkcache_preds = None  # speaker cache predictions
        self.fifo = None  # to save the embeddings from the latest chunks
        self.fifo_lengths = None
        self.fifo_preds = None
        self.spk_perm = None
        self.mean_sil_emb = None
        self.n_sil_frames = None


class SortformerDiarization:
    def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
        """
        Holds the shared streaming Sortformer diarization model, reused whenever a new
        online diarization session is initialized.
        """
        self._load_model(model_name)

    def _load_model(self, model_name: str):
        """Load and configure the Sortformer model for streaming."""
        try:
            self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
            self.diar_model.eval()

            if torch.cuda.is_available():
                self.diar_model.to(torch.device("cuda"))
                logger.info("Using CUDA for Sortformer model")
            else:
                logger.info("Using CPU for Sortformer model")
            self.diar_model.sortformer_modules.chunk_len = 10
            self.diar_model.sortformer_modules.subsampling_factor = 10
            self.diar_model.sortformer_modules.chunk_right_context = 0
            self.diar_model.sortformer_modules.chunk_left_context = 10
            self.diar_model.sortformer_modules.spkcache_len = 188
            self.diar_model.sortformer_modules.fifo_len = 188
            self.diar_model.sortformer_modules.spkcache_update_period = 144
            self.diar_model.sortformer_modules.log = False
            self.diar_model.sortformer_modules._check_streaming_parameters()

        except Exception as e:
            logger.error(f"Failed to load Sortformer model: {e}")
            raise


class SortformerDiarizationOnline:
    def __init__(self, shared_model, sample_rate: int = 16000):
        """
        Initialize a streaming Sortformer diarization session.

        Args:
            shared_model: SortformerDiarization instance holding the shared model
            sample_rate: Audio sample rate (default: 16000)
        """
        self.sample_rate = sample_rate
        self.speaker_segments = []
        self.buffer_audio = np.array([], dtype=np.float32)
        self.segment_lock = threading.Lock()
        self.global_time_offset = 0.0
        self.processed_time = 0.0
        self.debug = False

        self.diar_model = shared_model.diar_model

        self.audio2mel = AudioToMelSpectrogramPreprocessor(
            window_size=0.025,
            normalize="NA",
            n_fft=512,
            features=128,
            pad_to=0
        )

        self.chunk_duration_seconds = (
            self.diar_model.sortformer_modules.chunk_len *
            self.diar_model.sortformer_modules.subsampling_factor *
            self.diar_model.preprocessor._cfg.window_stride
        )
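        # With the values set in _load_model (chunk_len=10, subsampling_factor=10)
        # and the typical 10 ms window stride, this is 10 * 10 * 0.01 = 1.0 s per chunk.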

        self._init_streaming_state()

        self._previous_chunk_features = None
        self._chunk_index = 0
        self._len_prediction = None

        # Audio buffer to store PCM chunks for debugging
        self.audio_buffer = []

        # Buffer for accumulating audio chunks until reaching chunk_duration_seconds.
        # NOTE: currently unused in this module; raw-sample accumulation happens in
        # self.buffer_audio inside diarize().
        self.audio_chunk_buffer = []
        self.accumulated_duration = 0.0

        logger.info("SortformerDiarization initialized successfully")


    def _init_streaming_state(self):
        """Initialize the streaming state for the model."""
        batch_size = 1
        device = self.diar_model.device

        self.streaming_state = StreamingSortformerState()
        self.streaming_state.spkcache = torch.zeros(
            (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
            device=device
        )
        self.streaming_state.spkcache_preds = torch.zeros(
            (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
            device=device
        )
        self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
        self.streaming_state.fifo = torch.zeros(
            (batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
            device=device
        )
        self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
        self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
        self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)

        # Initialize total predictions tensor
        self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
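        # The time axis starts empty; presumably forward_streaming_step appends one
        # block of per-frame speaker probabilities for each processed chunk.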

    def insert_silence(self, silence_duration: float):
        """
        Insert a silence period by adjusting the global time offset.

        Args:
            silence_duration: Duration of silence in seconds
        """
        with self.segment_lock:
            self.global_time_offset += silence_duration
            logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")

    async def diarize(self, pcm_array: np.ndarray):
        """
        Process audio data for diarization in a streaming fashion.

        Args:
            pcm_array: Audio data as a numpy array
        """
        try:
            if self.debug:
                self.audio_buffer.append(pcm_array.copy())

            threshold = int(self.chunk_duration_seconds * self.sample_rate)

            self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
            if len(self.buffer_audio) < threshold:
                return

            audio = self.buffer_audio[:threshold]
            self.buffer_audio = self.buffer_audio[threshold:]

            audio_signal_chunk = torch.tensor(audio).unsqueeze(0).to(self.diar_model.device)
            audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(self.diar_model.device)

            processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
                audio_signal_chunk, audio_signal_length_chunk
            )
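
            # Prepend the tail of the previous chunk's mel features as left context.
            # The 99-frame overlap (~0.99 s, assuming the typical 10 ms hop) is a
            # constant of this backend, presumably matching the model's left context.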
            if self._previous_chunk_features is not None:
                to_add = self._previous_chunk_features[:, :, -99:]
                total_features = torch.concat([to_add, processed_signal_chunk], dim=2)
            else:
                total_features = processed_signal_chunk

            self._previous_chunk_features = processed_signal_chunk

            chunk_feat_seq_t = torch.transpose(total_features, 1, 2)

            with torch.inference_mode():
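                # Presumably, left_offset/right_offset tell forward_streaming_step how
                # many feature frames at each edge of the window are context only and
                # should not yield new predictions; the first chunk has no left context.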
                left_offset = 8 if self._chunk_index > 0 else 0
                right_offset = 8

                self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
                    processed_signal=chunk_feat_seq_t,
                    processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
                    streaming_state=self.streaming_state,
                    total_preds=self.total_preds,
                    left_offset=left_offset,
                    right_offset=right_offset,
                )

            # Convert predictions to speaker segments
            self._process_predictions()

            self._chunk_index += 1

        except Exception as e:
            logger.error(f"Error in diarize: {e}")
            raise

    # TODO: Handle the case where the stream ends with a partial buffer
    # (len(self.buffer_audio) > 0 but shorter than one full chunk).

    def _process_predictions(self):
        """Process model predictions and convert to speaker segments."""
        try:
            preds_np = self.total_preds[0].cpu().numpy()
            active_speakers = np.argmax(preds_np, axis=1)
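            # argmax keeps only the single most active speaker per frame, so
            # overlapping speech collapses to one label per frame.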

            if self._len_prediction is None:
                self._len_prediction = len(active_speakers)

            # Get predictions for current chunk
            frame_duration = self.chunk_duration_seconds / self._len_prediction
            current_chunk_preds = active_speakers[-self._len_prediction:]

            with self.segment_lock:
                # Process predictions into segments
                base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset

                for idx, spk in enumerate(current_chunk_preds):
                    start_time = base_time + idx * frame_duration
                    end_time = base_time + (idx + 1) * frame_duration

                    # Check if this continues the last segment or starts a new one
                    if (self.speaker_segments and
                        self.speaker_segments[-1].speaker == spk and
                        abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
                        # Continue existing segment
                        self.speaker_segments[-1].end = end_time
                    else:
                        # Create new segment
                        self.speaker_segments.append(SpeakerSegment(
                            speaker=spk,
                            start=start_time,
                            end=end_time
                        ))

                # Update processed time
                self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)

            logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")

        except Exception as e:
            logger.error(f"Error processing predictions: {e}")

    def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
        """
        Assign speakers to tokens based on timing overlap with speaker segments.

        Args:
            tokens: List of tokens with timing information
            use_punctuation_split: Whether to use punctuation for boundary refinement

        Returns:
            List of tokens with speaker assignments
        """
        with self.segment_lock:
            segments = self.speaker_segments.copy()

        if not segments or not tokens:
            logger.debug("No segments or tokens available for speaker assignment")
            return tokens

        logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
        # NOTE: punctuation-aware splitting is force-disabled here, so the
        # use_punctuation_split argument currently has no effect.
        use_punctuation_split = False
        if not use_punctuation_split:
            # Simple overlap-based assignment
            for token in tokens:
                token.speaker = -1  # Default to no speaker
                for segment in segments:
                    # Two intervals overlap iff neither ends before the other starts
                    if not (segment.end <= token.start or segment.start >= token.end):
                        token.speaker = segment.speaker + 1  # Convert to 1-based indexing
                        break
        else:
            # Use punctuation-aware assignment (similar to diart_backend)
            tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)

        return tokens

    def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
        """
        Assign speakers to tokens with punctuation-aware boundary adjustment.

        Args:
            segments: List of speaker segments
            tokens: List of tokens to assign speakers to

        Returns:
            List of tokens with speaker assignments
        """
        punctuation_marks = {'.', '!', '?'}
        punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]

        # Convert segments to concatenated format
        segments_concatenated = self._concatenate_speakers(segments)

        # Adjust segment boundaries based on punctuation
        for ind, segment in enumerate(segments_concatenated):
            for i, punctuation_token in enumerate(punctuation_tokens):
                if punctuation_token.start > segment['end']:
                    after_length = punctuation_token.start - segment['end']
                    before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
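
                    # Snap the segment boundary to the nearer punctuation mark: extend
                    # it forward to the upcoming mark, or pull it back to the previous one.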
                    if before_length > after_length:
                        segment['end'] = punctuation_token.start
                        if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                            segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                    else:
                        segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
                        if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                            segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                    break

        # Ensure non-overlapping tokens
        last_end = 0.0
        for token in tokens:
            start = max(last_end + 0.01, token.start)
            token.start = start
            token.end = max(start, token.end)
            last_end = token.end

        # Assign speakers based on adjusted segments
        ind_last_speaker = 0
        for segment in segments_concatenated:
            # Enumerate from ind_last_speaker so that i is an absolute index into tokens
            for i, token in enumerate(tokens[ind_last_speaker:], start=ind_last_speaker):
                if token.end <= segment['end']:
                    token.speaker = segment['speaker']
                    ind_last_speaker = i + 1
                elif token.start > segment['end']:
                    break

        return tokens

    def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
        """
        Concatenate consecutive segments from the same speaker.

        Args:
            segments: List of speaker segments

        Returns:
            List of concatenated speaker segments
        """
        if not segments:
            return []
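
        # Speaker indices are shifted to 1-based labels, matching the 1-based
        # speakers assigned in assign_speakers_to_tokens.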
        segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]

        for segment in segments[1:]:
            speaker = segment.speaker + 1
            if segments_concatenated[-1]['speaker'] != speaker:
                segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
            else:
                segments_concatenated[-1]['end'] = segment.end

        return segments_concatenated

    def get_segments(self) -> List[SpeakerSegment]:
        """Get a copy of the current speaker segments."""
        with self.segment_lock:
            return self.speaker_segments.copy()

    def clear_old_segments(self, older_than: float = 30.0):
        """Clear segments older than the specified time."""
        with self.segment_lock:
            current_time = self.processed_time
            self.speaker_segments = [
                segment for segment in self.speaker_segments
                if current_time - segment.end < older_than
            ]
            logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")

    def close(self):
        """Close the diarization system and clean up resources."""
        logger.info("Closing SortformerDiarization")
        with self.segment_lock:
            self.speaker_segments.clear()

        if self.debug:
            concatenated_audio = np.concatenate(self.audio_buffer)
            audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)  # scale float32 [-1, 1] to int16
            with wave.open("diarization_audio.wav", "wb") as wav_file:
                wav_file.setnchannels(1)  # mono audio
                wav_file.setsampwidth(2)  # 2 bytes per sample (int16)
                wav_file.setframerate(self.sample_rate)
                wav_file.writeframes(audio_data_int16.tobytes())
            logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")


def extract_number(s: str) -> int:
    """Extract number from speaker string (compatibility function)."""
    import re
    m = re.search(r'\d+', s)
    return int(m.group()) if m else 0


if __name__ == '__main__':
    import asyncio
    import librosa

    async def main():
        """TEST ONLY."""
        an4_audio = 'audio_test.mp3'
        signal, sr = librosa.load(an4_audio, sr=16000)
        signal = signal[:16000*30]

        print("\n" + "=" * 50)
        print("ground truth:")
        print("Speaker 0: 0:00 - 0:09")
        print("Speaker 1: 0:09 - 0:19")
        print("Speaker 2: 0:19 - 0:25")
        print("Speaker 0: 0:25 - 0:30")
        print("=" * 50)

        # SortformerDiarization only loads the shared model; streaming happens in
        # SortformerDiarizationOnline, which takes the shared model and sample rate.
        shared_model = SortformerDiarization()
        diarization = SortformerDiarizationOnline(shared_model, sample_rate=16000)
        chunk_size = 1600  # 0.1 s of audio at 16 kHz

        for i in range(0, len(signal), chunk_size):
            chunk = signal[i:i+chunk_size]
            await diarization.diarize(chunk)
            print(f"Processed chunk {i // chunk_size + 1}")

        segments = diarization.get_segments()
        print("\nDiarization results:")
        for segment in segments:
            print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")

    asyncio.run(main())