diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 850fd0a..8c9c8c0 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -6,7 +6,7 @@ import logging import traceback from datetime import timedelta from whisperlivekit.timed_objects import ASRToken, Silence -from whisperlivekit.core import TranscriptionEngine, online_factory +from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.remove_silences import handle_silences from whisperlivekit.trail_repetition import trim_tail_repetition @@ -63,7 +63,6 @@ class AudioProcessor: # Models and processing self.asr = models.asr self.tokenizer = models.tokenizer - self.diarization = models.diarization self.vac_model = models.vac_model if self.args.vac: self.vac = FixedVADIterator(models.vac_model) @@ -96,6 +95,11 @@ class AudioProcessor: # Initialize transcription engine if enabled if self.args.transcription: self.online = online_factory(self.args, models.asr, models.tokenizer) + + # Initialize diarization engine if enabled + if self.args.diarization: + self.diarization = online_diarization_factory(self.args, models.diarization_model) + def convert_pcm_to_float(self, pcm_buffer): """Convert PCM buffer in s16le format to normalized NumPy array.""" diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 97f07fa..bcab83f 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -121,14 +121,14 @@ class TranscriptionEngine: if self.args.diarization: if self.args.diarization_backend == "diart": from whisperlivekit.diarization.diart_backend import DiartDiarization - self.diarization = DiartDiarization( + self.diarization_model = DiartDiarization( block_duration=self.args.min_chunk_size, segmentation_model_name=self.args.segmentation_model, embedding_model_name=self.args.embedding_model ) elif 
self.args.diarization_backend == "sortformer": from whisperlivekit.diarization.sortformer_backend import SortformerDiarization - self.diarization = SortformerDiarization() + self.diarization_model = SortformerDiarization() else: raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}") @@ -153,4 +153,16 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr): confidence_validation = args.confidence_validation ) return online - \ No newline at end of file + + +def online_diarization_factory(args, diarization_backend): + if args.diarization_backend == "diart": + online = diarization_backend + # Not ideal, since several users/instances will share the same backend, but diart is no longer SOTA and sortformer is recommended + + if args.diarization_backend == "sortformer": + from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline + online = SortformerDiarizationOnline(shared_model=diarization_backend) + return online + + \ No newline at end of file diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index a55931c..dd35cfd 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -48,39 +48,12 @@ class StreamingSortformerState: class SortformerDiarization: - def __init__(self, sample_rate: int = 16000, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"): + def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"): """ - Initialize the streaming Sortformer diarization system. - - Args: - sample_rate: Audio sample rate (default: 16000) - model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2") + Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized. 
""" - self.sample_rate = sample_rate - self.speaker_segments = [] - self.buffer_audio = np.array([], dtype=np.float32) - self.segment_lock = threading.Lock() - self.global_time_offset = 0.0 - self.processed_time = 0.0 - self.debug = False - self._load_model(model_name) - - self._init_streaming_state() - - self._previous_chunk_features = None - self._chunk_index = 0 - self._len_prediction = None - - # Audio buffer to store PCM chunks for debugging - self.audio_buffer = [] - - # Buffer for accumulating audio chunks until reaching chunk_duration_seconds - self.audio_chunk_buffer = [] - self.accumulated_duration = 0.0 - - logger.info("SortformerDiarization initialized successfully") - + def _load_model(self, model_name: str): """Load and configure the Sortformer model for streaming.""" try: @@ -102,26 +75,59 @@ class SortformerDiarization: self.diar_model.sortformer_modules.spkcache_update_period = 144 self.diar_model.sortformer_modules.log = False self.diar_model.sortformer_modules._check_streaming_parameters() - - self.audio2mel = AudioToMelSpectrogramPreprocessor( - window_size=0.025, - normalize="NA", - n_fft=512, - features=128, - pad_to=0 - ) - - self.chunk_duration_seconds = ( - self.diar_model.sortformer_modules.chunk_len * - self.diar_model.sortformer_modules.subsampling_factor * - self.diar_model.preprocessor._cfg.window_stride - ) - - logger.info(f"Chunk duration: {self.chunk_duration_seconds:.2f}s") - + except Exception as e: logger.error(f"Failed to load Sortformer model: {e}") raise + +class SortformerDiarizationOnline: + def __init__(self, shared_model, sample_rate: int = 16000): + """ + Initialize the streaming Sortformer diarization system. 
+ + Args: + shared_model: Object holding the loaded model; its `diar_model` attribute is reused by this instance (e.g. a SortformerDiarization) + sample_rate: Audio sample rate (default: 16000) + """ + self.sample_rate = sample_rate + self.speaker_segments = [] + self.buffer_audio = np.array([], dtype=np.float32) + self.segment_lock = threading.Lock() + self.global_time_offset = 0.0 + self.processed_time = 0.0 + self.debug = False + + self.diar_model = shared_model.diar_model + + self.audio2mel = AudioToMelSpectrogramPreprocessor( + window_size=0.025, + normalize="NA", + n_fft=512, + features=128, + pad_to=0 + ) + + self.chunk_duration_seconds = ( + self.diar_model.sortformer_modules.chunk_len * + self.diar_model.sortformer_modules.subsampling_factor * + self.diar_model.preprocessor._cfg.window_stride + ) + + self._init_streaming_state() + + self._previous_chunk_features = None + self._chunk_index = 0 + self._len_prediction = None + + # Audio buffer to store PCM chunks for debugging + self.audio_buffer = [] + + # Buffer for accumulating audio chunks until reaching chunk_duration_seconds + self.audio_chunk_buffer = [] + self.accumulated_duration = 0.0 + + logger.info("SortformerDiarization initialized successfully") + def _init_streaming_state(self): """Initialize the streaming state for the model."""