diff --git a/architecture.png b/architecture.png
index b70d06e..3797148 100644
Binary files a/architecture.png and b/architecture.png differ
diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index c3e414a..7d01bc2 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -8,9 +8,9 @@ from datetime import timedelta
 from whisperlivekit.timed_objects import ASRToken, Silence
 from whisperlivekit.core import TranscriptionEngine, online_factory
 from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
-from .remove_silences import handle_silences
-from trail_repetition import trim_tail_repetition
-from silero_vad_iterator import FixedVADIterator
+from whisperlivekit.remove_silences import handle_silences
+from whisperlivekit.trail_repetition import trim_tail_repetition
+from whisperlivekit.silero_vad_iterator import FixedVADIterator
 # Set up logging once
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
@@ -228,9 +228,6 @@ class AudioProcessor:
 
         if self.args.vac:
             res = self.vac(pcm_array)
-            if self.silence:
-                print('NO AUDIO')
-
             if res is not None:
                 if res.get('end', 0) > res.get('start', 0):
                     end_of_audio = True
@@ -364,15 +361,25 @@ class AudioProcessor:
     async def diarization_processor(self, diarization_obj):
        """Process audio chunks for speaker diarization."""
         buffer_diarization = ""
-
+        cumulative_pcm_duration_stream_time = 0.0
         while True:
             try:
-                pcm_array = await self.diarization_queue.get()
-                if pcm_array is SENTINEL:
+                item = await self.diarization_queue.get()
+                if item is SENTINEL:
                     logger.debug("Diarization processor received sentinel. Finishing.")
                     self.diarization_queue.task_done()
                     break
+                if type(item) is Silence:
+                    cumulative_pcm_duration_stream_time += item.duration
+                    # self.diarization_obj.insert_silence(item.duration, self.tokens[-1].end)
+                    continue
+
+                if isinstance(item, np.ndarray):
+                    pcm_array = item
+                else:
+                    raise Exception('item should be pcm_array')
+
                 # Process diarization
                 await diarization_obj.diarize(pcm_array)
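
The reworked `diarization_processor` above dispatches on the type of each queue item: the sentinel ends the loop, `Silence` markers only advance the stream clock, and raw PCM goes to the model. A minimal, self-contained sketch of that pattern (the `SENTINEL`/`Silence` objects and queue wiring here are simplified stand-ins, not the exact WhisperLiveKit internals):

import asyncio
from dataclasses import dataclass

import numpy as np

SENTINEL = object()  # unique end-of-stream marker

@dataclass
class Silence:
    duration: float  # seconds of audio skipped upstream

async def consume(queue: asyncio.Queue) -> None:
    stream_time = 0.0
    while True:
        item = await queue.get()
        if item is SENTINEL:           # producer is done
            queue.task_done()
            break
        if isinstance(item, Silence):  # keep stream time aligned, skip inference
            stream_time += item.duration
            continue
        if not isinstance(item, np.ndarray):
            raise TypeError(f"expected PCM ndarray, got {type(item)!r}")
        stream_time += len(item) / 16000  # assuming 16 kHz mono PCM
        # ... hand `item` to the diarization backend here ...
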
diff --git a/whisperlivekit/basic_server.py b/whisperlivekit/basic_server.py
index afe62ab..9ce0a1e 100644
--- a/whisperlivekit/basic_server.py
+++ b/whisperlivekit/basic_server.py
@@ -47,7 +47,7 @@ async def handle_websocket_results(websocket, results_generator):
     except WebSocketDisconnect:
         logger.info("WebSocket disconnected while handling results (client likely closed connection).")
     except Exception as e:
-        logger.warning(f"Error in WebSocket results handler: {e}")
+        logger.error(f"Error in WebSocket results handler: {e}")
 
 
 @app.websocket("/asr")
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
index 3ca3d41..1231220 100644
--- a/whisperlivekit/core.py
+++ b/whisperlivekit/core.py
@@ -57,10 +57,10 @@ class TranscriptionEngine:
             "static_init_prompt": None,
             "max_context_tokens": None,
             "model_path": './base.pt',
+            "diarization_backend": "sortformer",
             # diart params:
             "segmentation_model": "pyannote/segmentation-3.0",
             "embedding_model": "pyannote/embedding",
-
         }
         config_dict = {**defaults, **kwargs}
@@ -119,12 +119,20 @@ class TranscriptionEngine:
             warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here
         if self.args.diarization:
-            from whisperlivekit.diarization.diarization_online import DiartDiarization
-            self.diarization = DiartDiarization(
-                block_duration=self.args.min_chunk_size,
-                segmentation_model_name=self.args.segmentation_model,
-                embedding_model_name=self.args.embedding_model
-            )
+            if self.args.diarization_backend == "diart":
+                from whisperlivekit.diarization.diart_backend import DiartDiarization
+                self.diarization = DiartDiarization(
+                    block_duration=self.args.min_chunk_size,
+                    segmentation_model_name=self.args.segmentation_model,
+                    embedding_model_name=self.args.embedding_model
+                )
+            elif self.args.diarization_backend == "sortformer":
+                from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
+                self.diarization = SortformerDiarization(
+                    model_name="nvidia/diar_streaming_sortformer_4spk-v2"
+                )
+            else:
+                raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
 
         TranscriptionEngine._initialized = True
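
With `diarization_backend` in the defaults, the backend is selected per engine instance. A hedged usage sketch (it assumes `diarization` is also one of the keyword defaults accepted by `TranscriptionEngine`, as the `self.args.diarization` check above suggests):

from whisperlivekit.core import TranscriptionEngine

# Sortformer is the default; pass "diart" to keep the pyannote pipeline.
engine = TranscriptionEngine(
    diarization=True,
    diarization_backend="diart",
    segmentation_model="pyannote/segmentation-3.0",
    embedding_model="pyannote/embedding",
)

Note that `TranscriptionEngine` guards itself with `_initialized`, so the backend is effectively fixed at first construction.
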
diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py
index d506f26..c3298aa 100644
--- a/whisperlivekit/diarization/sortformer_backend.py
+++ b/whisperlivekit/diarization/sortformer_backend.py
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 import logging
+from whisperlivekit.timed_objects import SpeakerSegment
 
 logger = logging.getLogger(__name__)
@@ -8,110 +9,137 @@
 try:
     from nemo.collections.asr.models import SortformerEncLabelModel
 except ImportError:
     raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
-
-diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
-diar_model.eval()
+class SortformerDiarization:
+    def __init__(self, model_name="nvidia/diar_streaming_sortformer_4spk-v2"):
+        self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
+        self.diar_model.eval()
 
-if torch.cuda.is_available():
-    diar_model.to(torch.device("cuda"))
-
-# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
-# diar_model.sortformer_modules.chunk_len = 6
-# diar_model.sortformer_modules.spkcache_len = 188
-# diar_model.sortformer_modules.chunk_right_context = 7
-# diar_model.sortformer_modules.fifo_len = 188
-# diar_model.sortformer_modules.spkcache_update_period = 144
-# diar_model.sortformer_modules.log = False
+        if torch.cuda.is_available():
+            self.diar_model.to(torch.device("cuda"))
+
+        # Streaming parameters for speed
+        self.diar_model.sortformer_modules.chunk_len = 12
+        self.diar_model.sortformer_modules.chunk_right_context = 1
+        self.diar_model.sortformer_modules.spkcache_len = 188
+        self.diar_model.sortformer_modules.fifo_len = 188
+        self.diar_model.sortformer_modules.spkcache_update_period = 144
+        self.diar_model.sortformer_modules.log = False
+        self.diar_model.sortformer_modules._check_streaming_parameters()
+
+        self.batch_size = 1
+        self.processed_signal_offset = torch.zeros((self.batch_size,), dtype=torch.long, device=self.diar_model.device)
+
+        self.audio_buffer = np.array([], dtype=np.float32)
+        self.sample_rate = 16000
+        self.speaker_segments = []
+
+        self.streaming_state = self.diar_model.sortformer_modules.init_streaming_state(
+            batch_size=self.batch_size,
+            async_streaming=True,
+            device=self.diar_model.device
+        )
+        self.total_preds = torch.zeros((self.batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=self.diar_model.device)
 
-# here we change the settings for our goal: speed!
-# we want batches of around 1 second. one frame is 0.08s, so 1s is 12.5 frames. we take 12.
-diar_model.sortformer_modules.chunk_len = 12
+    def _prepare_audio_signal(self, signal):
+        audio_signal = torch.tensor(signal).unsqueeze(0).to(self.diar_model.device)
+        audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(self.diar_model.device)
+        processed_signal, processed_signal_length = self.diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
+        return processed_signal, processed_signal_length
 
-# for more speed, we reduce the 'right context'. it's like looking less into the future.
-diar_model.sortformer_modules.chunk_right_context = 1
+    def _create_streaming_loader(self, processed_signal, processed_signal_length):
+        streaming_loader = self.diar_model.sortformer_modules.streaming_feat_loader(
+            feat_seq=processed_signal,
+            feat_seq_length=processed_signal_length,
+            feat_seq_offset=self.processed_signal_offset,
+        )
+        return streaming_loader
 
-# we keep the rest same for now
-diar_model.sortformer_modules.spkcache_len = 188
-diar_model.sortformer_modules.fifo_len = 188
-diar_model.sortformer_modules.spkcache_update_period = 144
-diar_model.sortformer_modules.log = False
-diar_model.sortformer_modules._check_streaming_parameters()
+    async def diarize(self, pcm_array: np.ndarray):
+        """
+        Process an incoming audio chunk for diarization.
+        """
+        self.audio_buffer = np.concatenate([self.audio_buffer, pcm_array])
+
+        # Process in fixed-size chunks (e.g., 1 second)
+        chunk_size = self.sample_rate  # 1 second of audio
+
+        while len(self.audio_buffer) >= chunk_size:
+            chunk_to_process = self.audio_buffer[:chunk_size]
+            self.audio_buffer = self.audio_buffer[chunk_size:]
 
-batch_size = 1
-processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)
+            processed_signal, processed_signal_length = self._prepare_audio_signal(chunk_to_process)
+
+            current_offset_seconds = self.processed_signal_offset.item() * self.diar_model.preprocessor._cfg.window_stride
 
-def prepare_audio_signal(signal):
-    audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
-    audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
-    processed_signal, processed_signal_length = diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
-    return processed_signal, processed_signal_length
+            streaming_loader = self._create_streaming_loader(processed_signal, processed_signal_length)
+
+            frame_duration_s = self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride
+            chunk_duration_seconds = self.diar_model.sortformer_modules.chunk_len * frame_duration_s
 
-def create_streaming_loader(processed_signal, processed_signal_length):
-    streaming_loader = diar_model.sortformer_modules.streaming_feat_loader(
-        feat_seq=processed_signal,
-        feat_seq_length=processed_signal_length,
-        feat_seq_offset=processed_signal_offset,
-    )
-    return streaming_loader
-
-
-def process_diarization(streaming_loader):
-
-    streaming_state = diar_model.sortformer_modules.init_streaming_state(
-        batch_size = batch_size,
-        async_streaming = True,
-        device = diar_model.device
-)
-    total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
-
-
-    chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
-    print(f"Chunk duration: {chunk_duration_seconds} seconds")
-
-    l_speakers = [
-        {'start_time': 0,
-         'end_time': 0,
-         'speaker': 0
-        }
-    ]
-    len_prediction = None
-    for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
-        with torch.inference_mode():
-            streaming_state, total_preds = diar_model.forward_streaming_step(
-                processed_signal=chunk_feat_seq_t,
-                processed_signal_length=feat_lengths,
-                streaming_state=streaming_state,
-                total_preds=total_preds,
-                left_offset=left_offset,
-                right_offset=right_offset,
-            )
-        preds_np = total_preds[0].cpu().numpy()
-        active_speakers = np.argmax(preds_np, axis=1)
-        if len_prediction is None:
-            len_prediction = len(active_speakers)  # we want to get the len of 1 prediction
-            frame_duration = chunk_duration_seconds / len_prediction
-        active_speakers = active_speakers[-len_prediction:]
-
-        for idx, spk in enumerate(active_speakers):
-            if spk != l_speakers[-1]['speaker']:
-                l_speakers.append(
-                    {'start_time': i * chunk_duration_seconds + idx * frame_duration,
-                     'end_time': i * chunk_duration_seconds + (idx + 1) * frame_duration,
-                     'speaker': spk
-                    })
-            else:
-                l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
+            for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
+                with torch.inference_mode():
+                    self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
+                        processed_signal=chunk_feat_seq_t,
+                        processed_signal_length=feat_lengths,
+                        streaming_state=self.streaming_state,
+                        total_preds=self.total_preds,
+                        left_offset=left_offset,
+                        right_offset=right_offset,
+                    )
 
-    print(l_speakers)
+                num_new_frames = feat_lengths[0].item()
+
+                # Get predictions for the current chunk from the end of total_preds
+                preds_np = self.total_preds[0, -num_new_frames:].cpu().numpy()
+                active_speakers = np.argmax(preds_np, axis=1)
+
+                for idx, spk in enumerate(active_speakers):
+                    start_time = current_offset_seconds + (i * chunk_duration_seconds) + (idx * frame_duration_s)
+                    end_time = start_time + frame_duration_s
+
+                    if self.speaker_segments and self.speaker_segments[-1].speaker == spk + 1:
+                        self.speaker_segments[-1].end = end_time
+                    else:
+                        self.speaker_segments.append(SpeakerSegment(
+                            speaker=int(spk + 1),
+                            start=start_time,
+                            end=end_time
+                        ))
+
+            self.processed_signal_offset += processed_signal_length
+
+    def assign_speakers_to_tokens(self, tokens: list, **kwargs) -> list:
+        """
+        Assign speakers to tokens based on timing overlap with speaker segments.
+        """
+        for token in tokens:
+            for segment in self.speaker_segments:
+                if not (segment.end <= token.start or segment.start >= token.end):
+                    token.speaker = segment.speaker
+        return tokens
+
+    def close(self):
+        """
+        Cleanup resources.
+        """
+        logger.info("Closing SortformerDiarization.")
 
 if __name__ == '__main__':
     import librosa
+    import asyncio
     an4_audio = 'new_audio_test.mp3'
-    signal, sr = librosa.load(an4_audio,sr=16000)
+    signal, sr = librosa.load(an4_audio, sr=16000)
+    diarization_pipeline = SortformerDiarization()
 
-    processed_signal, processed_signal_length = prepare_audio_signal(signal)
-    streaming_loader = create_streaming_loader(processed_signal, processed_signal_length)
-    process_diarization(streaming_loader)
\ No newline at end of file
+    # Simulate streaming
+    chunk_size = 16000  # 1 second
+    for i in range(0, len(signal), chunk_size):
+        chunk = signal[i:i+chunk_size]
+        asyncio.run(diarization_pipeline.diarize(chunk))
+
+    for segment in diarization_pipeline.speaker_segments:
+        print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
\ No newline at end of file
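
The chunk sizing in `__init__` is plain frame arithmetic: one Sortformer prediction frame spans `subsampling_factor * window_stride` seconds, which the removed comments put at 0.08 s, so 1 s of audio is 12.5 frames and `chunk_len = 12` covers 0.96 s per streaming step. A quick check of that arithmetic (the two constants below are assumptions consistent with the 0.08 s figure, not values read from the model config):

SUBSAMPLING_FACTOR = 8   # assumed encoder subsampling
WINDOW_STRIDE_S = 0.01   # assumed preprocessor hop, in seconds

frame_duration_s = SUBSAMPLING_FACTOR * WINDOW_STRIDE_S  # 0.08 s per prediction frame
chunk_len = 12
print(frame_duration_s * chunk_len)  # 0.96 s of audio consumed per forward_streaming_step
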
+ """ + logger.info("Closing SortformerDiarization.") if __name__ == '__main__': import librosa an4_audio = 'new_audio_test.mp3' - signal, sr = librosa.load(an4_audio,sr=16000) + signal, sr = librosa.load(an4_audio, sr=16000) + diarization_pipeline = SortformerDiarization() - processed_signal, processed_signal_length = prepare_audio_signal(signal) - streaming_loader = create_streaming_loader(processed_signal, processed_signal_length) - process_diarization(streaming_loader) \ No newline at end of file + # Simulate streaming + chunk_size = 16000 # 1 second + for i in range(0, len(signal), chunk_size): + chunk = signal[i:i+chunk_size] + import asyncio + asyncio.run(diarization_pipeline.diarize(chunk)) + + for segment in diarization_pipeline.speaker_segments: + print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s") \ No newline at end of file diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 8335f55..d8c903c 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -58,6 +58,14 @@ def parse_args(): help="Hugging Face model ID for pyannote.audio embedding model.", ) + parser.add_argument( + "--diarization-backend", + type=str, + default="sortformer", + choices=["sortformer", "diart"], + help="The diarization backend to use.", + ) + parser.add_argument( "--no-transcription", action="store_true",