Diarization: use an rx Observer instead of diart's attach_hooks method

This commit is contained in:
Quentin Fuxa
2025-03-13 12:02:18 +01:00
parent 7b582f3f9f
commit 3024a9bdb2
3 changed files with 107 additions and 43 deletions
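For orientation, a minimal sketch of the observer wiring this commit moves to, using only the diart and rx APIs that the diff below imports (the observer class name and the print-based handlers are illustrative, not the project's code):

# Sketch: subscribe an rx Observer to diart's StreamingInference instead of
# attaching a hook function. PrintingObserver is a hypothetical name.
from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import MicrophoneAudioSource
from rx.core import Observer

class PrintingObserver(Observer):
    def on_next(self, value):
        annotation, audio = value          # (pyannote Annotation, audio chunk)
        print(annotation)

    def on_error(self, error):
        print(f"Diarization stream error: {error}")

    def on_completed(self):
        print("Diarization stream completed")

inference = StreamingInference(
    pipeline=SpeakerDiarization(),
    source=MicrophoneAudioSource(),
    do_plot=False,
    show_progress=False,
)
inference.attach_observers(PrintingObserver())
inference()  # blocks until the audio source is closed or exhausted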

View File

@@ -2,16 +2,79 @@ import asyncio
import re
import threading
import numpy as np
import logging
from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import AudioSource
from timed_objects import SpeakerSegment
from diart.sources import MicrophoneAudioSource
from rx.core import Observer
from typing import Tuple, Any, List, Optional
from pyannote.core import Annotation
logger = logging.getLogger(__name__)
def extract_number(s: str) -> Optional[int]:
m = re.search(r'\d+', s)
return int(m.group()) if m else None
class DiarizationObserver(Observer):
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
def __init__(self):
self.speaker_segments = []
self.processed_time = 0
self.segment_lock = threading.Lock()
def on_next(self, value: Tuple[Annotation, Any]):
annotation, audio = value
logger.debug("\n--- New Diarization Result ---")
duration = audio.extent.end - audio.extent.start
logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
logger.debug(f"Audio shape: {audio.data.shape}")
with self.segment_lock:
if audio.extent.end > self.processed_time:
self.processed_time = audio.extent.end
if annotation and len(annotation._labels) > 0:
logger.debug("\nSpeaker segments:")
for speaker, label in annotation._labels.items():
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
self.speaker_segments.append(SpeakerSegment(
speaker=speaker,
start=start,
end=end
))
else:
logger.debug("\nNo speakers detected in this segment")
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""
with self.segment_lock:
return self.speaker_segments.copy()
def clear_old_segments(self, older_than: float = 30.0):
"""Clear segments older than the specified time."""
with self.segment_lock:
current_time = self.processed_time
self.speaker_segments = [
segment for segment in self.speaker_segments
if current_time - segment.end < older_than
]
def on_error(self, error):
"""Handle an error in the stream."""
logger.debug(f"Error in diarization stream: {error}")
def on_completed(self):
"""Handle the completion of the stream."""
logger.debug("Diarization stream completed")
class WebSocketAudioSource(AudioSource):
"""
@@ -34,57 +97,57 @@ class WebSocketAudioSource(AudioSource):
def push_audio(self, chunk: np.ndarray):
if not self._closed:
self.stream.on_next(np.expand_dims(chunk, axis=0))
new_audio = np.expand_dims(chunk, axis=0)
logger.debug(f'Add new chunk with shape: {new_audio.shape}')
self.stream.on_next(new_audio)
class DiartDiarization:
def __init__(self, sample_rate: int):
self.processed_time = 0
self.segment_speakers = []
self.speakers_queue = asyncio.Queue()
self.pipeline = SpeakerDiarization()
self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
def __init__(self, sample_rate: int, use_microphone: bool = False):
self.pipeline = SpeakerDiarization()
self.observer = DiarizationObserver()
if use_microphone:
self.source = MicrophoneAudioSource()
self.custom_source = None
else:
self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
self.source = self.custom_source
self.inference = StreamingInference(
pipeline=self.pipeline,
source=self.source,
do_plot=False,
show_progress=False,
)
# Attach the hook function and start inference in the background.
self.inference.attach_hooks(self._diar_hook)
self.inference.attach_observers(self.observer)
asyncio.get_event_loop().run_in_executor(None, self.inference)
def _diar_hook(self, result):
annotation, audio = result
if annotation._labels:
for speaker, label in annotation._labels.items():
start = label.segments_boundaries_[0]
end = label.segments_boundaries_[-1]
if end > self.processed_time:
self.processed_time = end
asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
speaker=speaker,
start=start,
end=end,
)))
else:
dur = audio.extent.end
if dur > self.processed_time:
self.processed_time = dur
async def diarize(self, pcm_array: np.ndarray):
self.source.push_audio(pcm_array)
self.segment_speakers.clear()
while not self.speakers_queue.empty():
self.segment_speakers.append(await self.speakers_queue.get())
"""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
if self.custom_source:
self.custom_source.push_audio(pcm_array)
self.observer.clear_old_segments()
return self.observer.get_segments()
def close(self):
self.source.close()
"""Close the audio source."""
if self.custom_source:
self.custom_source.close()
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
"""
segments = self.observer.get_segments()
for token in tokens:
for segment in self.segment_speakers:
for segment in segments:
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = extract_number(segment.speaker) + 1
end_attributed_speaker = max(token.end, end_attributed_speaker)
return end_attributed_speaker
return end_attributed_speaker
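
A hedged usage sketch of the reworked DiartDiarization class above; the 16 kHz sample rate, the silent one-second chunks, and the empty token list are placeholder assumptions, and the class is assumed importable from the module shown in this diff:

import asyncio
import numpy as np

async def demo():
    # Assumes DiartDiarization from the diff above is importable.
    diar = DiartDiarization(sample_rate=16000)
    end_attributed_speaker = 0.0
    for _ in range(10):
        chunk = np.zeros(16000, dtype=np.float32)   # one second of silence
        segments = await diar.diarize(chunk)        # push audio, read observer segments
        end_attributed_speaker = diar.assign_speakers_to_tokens(
            end_attributed_speaker, tokens=[]       # real callers pass ASR tokens here
        )
        print(len(segments), end_attributed_speaker)
        await asyncio.sleep(1)
    diar.close()

asyncio.run(demo())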

View File

@@ -8,6 +8,7 @@ class TimedText:
text: Optional[str] = ''
speaker: Optional[int] = -1
probability: Optional[float] = None
is_dummy: Optional[bool] = False
@dataclass
class ASRToken(TimedText):

View File

@@ -49,7 +49,7 @@ parser.add_argument(
parser.add_argument(
"--confidence-validation",
type=bool,
default=True,
default=False,
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
)
@@ -110,9 +110,10 @@ class SharedState:
current_time = time() - self.beg_loop
dummy_token = ASRToken(
start=current_time,
end=current_time + 0.5,
text="",
speaker=-1
end=current_time + 1,
text=".",
speaker=-1,
is_dummy=True
)
self.tokens.append(dummy_token)
@@ -275,14 +276,13 @@ async def results_formatter(shared_state, websocket):
sep = state["sep"]
# If diarization is enabled but no transcription, add dummy tokens periodically
if not tokens and not args.transcription and args.diarization:
if (not tokens or tokens[-1].is_dummy) and not args.transcription and args.diarization:
await shared_state.add_dummy_token()
# Re-fetch tokens after adding dummy
sleep(0.5)
state = await shared_state.get_current_state()
tokens = state["tokens"]
# Process tokens to create response
previous_speaker = -10
previous_speaker = -1
lines = []
last_end_diarized = 0
undiarized_text = []
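
For readers following the previous_speaker change above, a minimal sketch of the grouping task this loop performs, collecting consecutive tokens into per-speaker lines; variable names mirror the diff, but this is only an illustration, not the project's actual results_formatter logic:

def group_tokens_by_speaker(tokens, sep=" "):
    """Toy illustration: open a new line whenever the speaker changes."""
    previous_speaker = -1
    lines = []
    for token in tokens:
        if token.speaker != previous_speaker:
            lines.append({"speaker": token.speaker,
                          "beg": token.start, "end": token.end, "text": ""})
            previous_speaker = token.speaker
        lines[-1]["text"] += (token.text or "") + sep
        lines[-1]["end"] = token.end
    return lines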