diff --git a/diarization/diarization_online.py b/diarization/diarization_online.py
index 9dafce4..d04c3a2 100644
--- a/diarization/diarization_online.py
+++ b/diarization/diarization_online.py
@@ -2,16 +2,79 @@ import asyncio
 import re
 import threading
 import numpy as np
+import logging
+
 from diart import SpeakerDiarization
 from diart.inference import StreamingInference
 from diart.sources import AudioSource
 from timed_objects import SpeakerSegment
+from diart.sources import MicrophoneAudioSource
+from rx.core import Observer
+from typing import Tuple, Any, List
+from pyannote.core import Annotation
+
+logger = logging.getLogger(__name__)
 
 
 def extract_number(s: str) -> int:
     m = re.search(r'\d+', s)
     return int(m.group()) if m else None
 
+
+class DiarizationObserver(Observer):
+    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
+
+    def __init__(self):
+        self.speaker_segments = []
+        self.processed_time = 0
+        self.segment_lock = threading.Lock()
+
+    def on_next(self, value: Tuple[Annotation, Any]):
+        annotation, audio = value
+
+        logger.debug("\n--- New Diarization Result ---")
+
+        duration = audio.extent.end - audio.extent.start
+        logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
+        logger.debug(f"Audio shape: {audio.data.shape}")
+
+        with self.segment_lock:
+            if audio.extent.end > self.processed_time:
+                self.processed_time = audio.extent.end
+            if annotation and len(annotation._labels) > 0:
+                logger.debug("\nSpeaker segments:")
+                for speaker, label in annotation._labels.items():
+                    for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
+                        logger.debug(f"  {speaker}: {start:.2f}s-{end:.2f}s")
+                        self.speaker_segments.append(SpeakerSegment(
+                            speaker=speaker,
+                            start=start,
+                            end=end
+                        ))
+            else:
+                logger.debug("\nNo speakers detected in this segment")
+
+    def get_segments(self) -> List[SpeakerSegment]:
+        """Get a copy of the current speaker segments."""
+        with self.segment_lock:
+            return self.speaker_segments.copy()
+
+    def clear_old_segments(self, older_than: float = 30.0):
+        """Clear segments older than the specified time."""
+        with self.segment_lock:
+            current_time = self.processed_time
+            self.speaker_segments = [
+                segment for segment in self.speaker_segments
+                if current_time - segment.end < older_than
+            ]
+
+    def on_error(self, error):
+        """Handle an error in the stream."""
+        logger.error(f"Error in diarization stream: {error}")
+
+    def on_completed(self):
+        """Handle the completion of the stream."""
+        logger.debug("Diarization stream completed")
+
 
 class WebSocketAudioSource(AudioSource):
     """
@@ -34,57 +97,57 @@ class WebSocketAudioSource(AudioSource):
 
     def push_audio(self, chunk: np.ndarray):
         if not self._closed:
-            self.stream.on_next(np.expand_dims(chunk, axis=0))
+            new_audio = np.expand_dims(chunk, axis=0)
+            logger.debug("Add new chunk with shape: %s", new_audio.shape)
+            self.stream.on_next(new_audio)
 
 
 class DiartDiarization:
-    def __init__(self, sample_rate: int):
-        self.processed_time = 0
-        self.segment_speakers = []
-        self.speakers_queue = asyncio.Queue()
-        self.pipeline = SpeakerDiarization()
-        self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+    def __init__(self, sample_rate: int, use_microphone: bool = False):
+        self.pipeline = SpeakerDiarization()
+        self.observer = DiarizationObserver()
+
+        if use_microphone:
+            self.source = MicrophoneAudioSource()
+            self.custom_source = None
+        else:
+            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+            self.source = self.custom_source
+
         self.inference = StreamingInference(
             pipeline=self.pipeline,
             source=self.source,
             do_plot=False,
             show_progress=False,
         )
-        # Attach the hook function and start inference in the background.
-        self.inference.attach_hooks(self._diar_hook)
+        # Attach the observer and start inference in the background.
+        self.inference.attach_observers(self.observer)
         asyncio.get_event_loop().run_in_executor(None, self.inference)
 
-    def _diar_hook(self, result):
-        annotation, audio = result
-        if annotation._labels:
-            for speaker, label in annotation._labels.items():
-                start = label.segments_boundaries_[0]
-                end = label.segments_boundaries_[-1]
-                if end > self.processed_time:
-                    self.processed_time = end
-                asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
-                    speaker=speaker,
-                    start=start,
-                    end=end,
-                )))
-        else:
-            dur = audio.extent.end
-            if dur > self.processed_time:
-                self.processed_time = dur
-
     async def diarize(self, pcm_array: np.ndarray):
-        self.source.push_audio(pcm_array)
-        self.segment_speakers.clear()
-        while not self.speakers_queue.empty():
-            self.segment_speakers.append(await self.speakers_queue.get())
+        """
+        Process audio data for diarization.
+        Only used when working with WebSocketAudioSource.
+        """
+        if self.custom_source:
+            self.custom_source.push_audio(pcm_array)
+        self.observer.clear_old_segments()
+        return self.observer.get_segments()
 
     def close(self):
-        self.source.close()
+        """Close the audio source."""
+        if self.custom_source:
+            self.custom_source.close()
 
-    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
+        """
+        Assign speakers to tokens based on timing overlap with speaker segments.
+        Uses the segments collected by the observer.
+        """
+        segments = self.observer.get_segments()
+
         for token in tokens:
-            for segment in self.segment_speakers:
+            for segment in segments:
                 if not (segment.end <= token.start or segment.start >= token.end):
                     token.speaker = extract_number(segment.speaker) + 1
                     end_attributed_speaker = max(token.end, end_attributed_speaker)
-        return end_attributed_speaker
+        return end_attributed_speaker
\ No newline at end of file
diff --git a/timed_objects.py b/timed_objects.py
index c1c3e4e..c347b9a 100644
--- a/timed_objects.py
+++ b/timed_objects.py
@@ -8,6 +8,7 @@ class TimedText:
     text: Optional[str] = ''
     speaker: Optional[int] = -1
     probability: Optional[float] = None
+    is_dummy: Optional[bool] = False
 
 @dataclass
 class ASRToken(TimedText):
diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py
index d55e104..8a08381 100644
--- a/whisper_fastapi_online_server.py
+++ b/whisper_fastapi_online_server.py
@@ -49,7 +49,7 @@ parser.add_argument(
 parser.add_argument(
     "--confidence-validation",
     type=bool,
-    default=True,
+    default=False,
    help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
 )
@@ -110,9 +110,10 @@ class SharedState:
         current_time = time() - self.beg_loop
         dummy_token = ASRToken(
             start=current_time,
-            end=current_time + 0.5,
-            text="",
-            speaker=-1
+            end=current_time + 1,
+            text=".",
+            speaker=-1,
+            is_dummy=True
         )
         self.tokens.append(dummy_token)
@@ -275,14 +276,13 @@ async def results_formatter(shared_state, websocket):
         sep = state["sep"]
 
         # If diarization is enabled but no transcription, add dummy tokens periodically
-        if not tokens and not args.transcription and args.diarization:
+        if (not tokens or tokens[-1].is_dummy) and not args.transcription and args.diarization:
             await shared_state.add_dummy_token()
-            # Re-fetch tokens after adding dummy
+            await asyncio.sleep(0.5)
            state = await shared_state.get_current_state()
             tokens = state["tokens"]
 
-        # Process tokens to create response
-        previous_speaker = -10
+        previous_speaker = -1
         lines = []
         last_end_diarized = 0
         undiarized_text = []
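
For context when reviewing, here is a minimal usage sketch of the reworked observer-based flow; it is not code from this patch. It assumes a 16 kHz stream, assumes `TimedText` also defines `start`/`end` fields (only `text`, `speaker`, `probability`, and the new `is_dummy` are visible in this hunk), and substitutes random noise for real WebSocket audio, so diart will not emit meaningful labels here; the point is only the call order of `diarize()` and `assign_speakers_to_tokens()`.

```python
# Hypothetical driver for the new observer-based API; not part of the diff.
# Assumes a 16 kHz stream and synthetic audio standing in for WebSocket chunks.
import asyncio
import numpy as np

from diarization.diarization_online import DiartDiarization
from timed_objects import ASRToken


async def main():
    # WebSocket-style source (use_microphone defaults to False).
    # The StreamingInference pipeline starts in a background executor.
    diarizer = DiartDiarization(sample_rate=16000)

    # Push 0.5 s chunks of placeholder PCM; diarize() forwards them to the
    # custom source and returns the observer's current speaker segments.
    for _ in range(20):
        chunk = np.random.randn(8000).astype(np.float32)
        segments = await diarizer.diarize(chunk)
        await asyncio.sleep(0.5)  # give the background inference time to emit

    # Attribute speakers to ASR tokens by timing overlap with the segments.
    tokens = [
        ASRToken(start=0.0, end=1.0, text="hello"),
        ASRToken(start=1.0, end=2.0, text="world"),
    ]
    end_attributed_speaker = diarizer.assign_speakers_to_tokens(0.0, tokens)
    for token in tokens:
        print(token.text, token.speaker)

    diarizer.close()


asyncio.run(main())
```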