From 253a080df50b6f4bbe74131a627b27277baa24df Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Tue, 19 Aug 2025 21:12:55 +0200 Subject: [PATCH] diart diarization handles pauses/silences thanks to offset --- whisperlivekit/audio_processor.py | 2 +- whisperlivekit/core.py | 7 ++----- whisperlivekit/diarization/diart_backend.py | 8 ++++++-- whisperlivekit/parse_args.py | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 7d01bc2..ba31e46 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -372,7 +372,7 @@ class AudioProcessor: if type(item) is Silence: cumulative_pcm_duration_stream_time += item.duration - # self.diarization_obj.insert_silence(item.duration, self.tokens[-1].end) + diarization_obj.insert_silence(item.duration) continue if isinstance(item, np.ndarray): diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 1231220..4a3b1ce 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -57,7 +57,7 @@ class TranscriptionEngine: "static_init_prompt": None, "max_context_tokens": None, "model_path": './base.pt', - "diarization_backend": "sortformer", + "diarization_backend": "diart", # diart params: "segmentation_model": "pyannote/segmentation-3.0", "embedding_model": "pyannote/embedding", @@ -127,10 +127,7 @@ class TranscriptionEngine: embedding_model_name=self.args.embedding_model ) elif self.args.diarization_backend == "sortformer": - from whisperlivekit.diarization.sortformer_backend import SortformerDiarization - self.diarization = SortformerDiarization( - model_name="nvidia/diar_streaming_sortformer_4spk-v2" - ) + raise ValueError('Sortformer backend in development') else: raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}") diff --git a/whisperlivekit/diarization/diart_backend.py b/whisperlivekit/diarization/diart_backend.py index e42584c..6c578cb 100644 --- 
a/whisperlivekit/diarization/diart_backend.py +++ b/whisperlivekit/diarization/diart_backend.py @@ -29,6 +29,7 @@ class DiarizationObserver(Observer): self.speaker_segments = [] self.processed_time = 0 self.segment_lock = threading.Lock() + self.global_time_offset = 0.0 def on_next(self, value: Tuple[Annotation, Any]): annotation, audio = value @@ -49,8 +50,8 @@ class DiarizationObserver(Observer): print(f" {speaker}: {start:.2f}s-{end:.2f}s") self.speaker_segments.append(SpeakerSegment( speaker=speaker, - start=start, - end=end + start=start + self.global_time_offset, + end=end + self.global_time_offset )) else: logger.debug("\nNo speakers detected in this segment") @@ -199,6 +200,9 @@ class DiartDiarization: self.inference.attach_observers(self.observer) asyncio.get_event_loop().run_in_executor(None, self.inference) + def insert_silence(self, silence_duration): + self.observer.global_time_offset += silence_duration + async def diarize(self, pcm_array: np.ndarray): """ Process audio data for diarization. diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index d8c903c..9e54698 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -61,7 +61,7 @@ def parse_args(): parser.add_argument( "--diarization-backend", type=str, - default="sortformer", + default="diart", choices=["sortformer", "diart"], help="The diarization backend to use.", )