import asyncio
import logging
import re
import threading
from typing import Any, List, Optional, Tuple

import numpy as np
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.sources import AudioSource, MicrophoneAudioSource
from pyannote.core import Annotation
from rx.core import Observer

from whisperlivekit.timed_objects import SpeakerSegment

logger = logging.getLogger(__name__)


def extract_number(s: str) -> Optional[int]:
    """Return the first integer found in ``s`` (e.g. "speaker1" -> 1), or None if there is none."""
    m = re.search(r'\d+', s)
    return int(m.group()) if m else None


class DiarizationObserver(Observer):
    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""

    def __init__(self):
        self.speaker_segments = []
        self.processed_time = 0
        self.segment_lock = threading.Lock()

    def on_next(self, value: Tuple[Annotation, Any]):
        annotation, audio = value

        logger.debug("\n--- New Diarization Result ---")

        duration = audio.extent.end - audio.extent.start
        logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
        logger.debug(f"Audio shape: {audio.data.shape}")

        with self.segment_lock:
            if audio.extent.end > self.processed_time:
                self.processed_time = audio.extent.end
            if annotation and len(annotation._labels) > 0:
                logger.debug("\nSpeaker segments:")
                for speaker, label in annotation._labels.items():
                    for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
                        logger.debug(f"  {speaker}: {start:.2f}s-{end:.2f}s")
                        self.speaker_segments.append(SpeakerSegment(
                            speaker=speaker,
                            start=start,
                            end=end
                        ))
            else:
                logger.debug("\nNo speakers detected in this segment")

    def get_segments(self) -> List[SpeakerSegment]:
        """Get a copy of the current speaker segments."""
        with self.segment_lock:
            return self.speaker_segments.copy()

    def clear_old_segments(self, older_than: float = 30.0):
        """Clear segments that ended more than ``older_than`` seconds before the latest processed time."""
        with self.segment_lock:
            current_time = self.processed_time
            self.speaker_segments = [
                segment for segment in self.speaker_segments
                if current_time - segment.end < older_than
            ]

    def on_error(self, error):
        """Handle an error in the stream."""
        logger.debug(f"Error in diarization stream: {error}")

    def on_completed(self):
        """Handle the completion of the stream."""
        logger.debug("Diarization stream completed")
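

# Hedged usage sketch (illustrative, not part of the module): consumers poll
# the observer while the pipeline runs in a background thread, e.g.
#
#     observer = DiarizationObserver()
#     # ... attached via StreamingInference.attach_observers(), as below ...
#     for seg in observer.get_segments():
#         print(seg.speaker, seg.start, seg.end)
#     observer.clear_old_segments(older_than=30.0)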


class WebSocketAudioSource(AudioSource):
    """
    Custom AudioSource that blocks in read() until close() is called.
    Use push_audio() to inject PCM chunks.
    """
    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
        super().__init__(uri, sample_rate)
        self._closed = False
        self._close_event = threading.Event()

    def read(self):
        self._close_event.wait()

    def close(self):
        if not self._closed:
            self._closed = True
            self.stream.on_completed()
            self._close_event.set()

    def push_audio(self, chunk: np.ndarray):
        if not self._closed:
            # Add a channel axis so the chunk is shaped (1, samples), as diart expects.
            new_audio = np.expand_dims(chunk, axis=0)
            logger.debug("Add new chunk with shape: %s", new_audio.shape)
            self.stream.on_next(new_audio)
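

# Hedged sketch (sizes and uri are illustrative): push one second of 16 kHz
# mono float32 PCM into the source, then close it, which completes the
# stream and unblocks read().
#
#     source = WebSocketAudioSource(uri="demo", sample_rate=16000)
#     source.push_audio(np.zeros(16000, dtype=np.float32))
#     source.close()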


class DiartDiarization:
    def __init__(self, sample_rate: int = 16000, config: Optional[SpeakerDiarizationConfig] = None, use_microphone: bool = False):
        self.pipeline = SpeakerDiarization(config=config)
        self.observer = DiarizationObserver()

        if use_microphone:
            self.source = MicrophoneAudioSource()
            self.custom_source = None
        else:
            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
            self.source = self.custom_source

        self.inference = StreamingInference(
            pipeline=self.pipeline,
            source=self.source,
            do_plot=False,
            show_progress=False,
        )
        self.inference.attach_observers(self.observer)
        # Run the blocking inference loop in a worker thread so the event loop stays free.
        asyncio.get_event_loop().run_in_executor(None, self.inference)

    async def diarize(self, pcm_array: np.ndarray):
        """
        Process audio data for diarization.
        Only used when working with a WebSocketAudioSource.
        """
        if self.custom_source:
            self.custom_source.push_audio(pcm_array)
        self.observer.clear_old_segments()
        return self.observer.get_segments()

    def close(self):
        """Close the audio source."""
        if self.custom_source:
            self.custom_source.close()

    def assign_speakers_to_tokens(self, end_attributed_speaker: float, tokens: list) -> float:
        """
        Assign speakers to tokens based on timing overlap with speaker segments.
        Uses the segments collected by the observer.
        """
        segments = self.observer.get_segments()

        for token in tokens:
            for segment in segments:
                # Overlap test: the token and segment intervals intersect.
                if not (segment.end <= token.start or segment.start >= token.end):
                    # Speaker labels like "speaker0" carry a trailing index; convert to a 1-based id.
                    token.speaker = extract_number(segment.speaker) + 1
                    end_attributed_speaker = max(token.end, end_attributed_speaker)
        return end_attributed_speaker
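

# Hedged end-to-end sketch (not part of the library; assumes the default diart
# models can be loaded locally). Feeds short chunks of silence through the
# WebSocket-backed pipeline and prints whatever segments the observer holds.
if __name__ == "__main__":
    async def _demo():
        diarization = DiartDiarization(sample_rate=16000)
        for _ in range(10):
            # 0.5 s of silence stands in for real WebSocket PCM audio.
            segments = await diarization.diarize(np.zeros(8000, dtype=np.float32))
            print(segments)
            await asyncio.sleep(0.5)
        diarization.close()

    asyncio.run(_demo())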