# Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
# Synced 2026-03-07 22:33:36 +00:00
import asyncio
import logging
import re
import threading
import time
from queue import Empty, SimpleQueue
from typing import Any, List, Optional, Tuple

import diart.models as m
import numpy as np
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.sources import AudioSource, MicrophoneAudioSource
from pyannote.core import Annotation
from rx.core import Observer

from whisperlivekit.timed_objects import SpeakerSegment
|
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
|
def extract_number(s: str) -> "Optional[int]":
    """Return the first run of digits in *s* as an int, or None if *s* has none.

    Used to turn diart speaker labels like ``"speaker0"`` into numeric ids.
    """
    # Renamed from `m`, which shadowed the module alias `m` (diart.models).
    match = re.search(r'\d+', s)
    return int(match.group()) if match else None
class DiarizationObserver(Observer):
    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""

    def __init__(self):
        # Accumulated SpeakerSegment results; protected by segment_lock.
        self.diarization_segments = []
        # Highest audio-extent end time seen so far (stream seconds).
        self.processed_time = 0
        self.segment_lock = threading.Lock()
        # Offset added to segment times to account for inserted silence
        # (advanced by DiartDiarization.insert_silence).
        self.global_time_offset = 0.0

    def on_next(self, value: Tuple[Annotation, Any]):
        """Handle one pipeline emission: an (annotation, audio-chunk) pair."""
        annotation, audio = value

        logger.debug("\n--- New Diarization Result ---")

        duration = audio.extent.end - audio.extent.start
        logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
        logger.debug(f"Audio shape: {audio.data.shape}")

        with self.segment_lock:
            if audio.extent.end > self.processed_time:
                self.processed_time = audio.extent.end
            if annotation and len(annotation._labels) > 0:
                logger.debug("\nSpeaker segments:")
                for speaker, label in annotation._labels.items():
                    for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
                        # Was a stray print(); route through the module logger
                        # like every other diagnostic line in this class.
                        logger.debug(f"  {speaker}: {start:.2f}s-{end:.2f}s")
                        self.diarization_segments.append(SpeakerSegment(
                            speaker=speaker,
                            start=start + self.global_time_offset,
                            end=end + self.global_time_offset
                        ))
            else:
                logger.debug("\nNo speakers detected in this segment")

    def get_segments(self) -> List[SpeakerSegment]:
        """Get a copy of the current speaker segments."""
        with self.segment_lock:
            return self.diarization_segments.copy()

    def clear_old_segments(self, older_than: float = 30.0):
        """Clear segments older than the specified time."""
        with self.segment_lock:
            current_time = self.processed_time
            self.diarization_segments = [
                segment for segment in self.diarization_segments
                if current_time - segment.end < older_than
            ]

    def on_error(self, error):
        """Handle an error in the stream."""
        # Was logger.debug, which hid stream failures at default log levels.
        logger.error(f"Error in diarization stream: {error}")

    def on_completed(self):
        """Handle the completion of the stream."""
        logger.debug("Diarization stream completed")
|
class WebSocketAudioSource(AudioSource):
    """
    Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
    """
    def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
        super().__init__(uri, sample_rate)
        # Target pacing interval between emitted chunks, in seconds.
        self.block_duration = block_duration
        # Number of samples per emitted chunk.
        self.block_size = int(np.rint(block_duration * sample_rate))
        # Producer/consumer handoff: push_audio() enqueues, _process_chunks() drains.
        self._queue = SimpleQueue()
        # Staging area for samples that do not yet fill a whole block.
        self._buffer = np.array([], dtype=np.float32)
        self._buffer_lock = threading.Lock()
        self._closed = False
        # Set by close(); read() blocks on it until shutdown.
        self._close_event = threading.Event()
        self._processing_thread = None
        # Wall-clock time of the most recently emitted chunk, used for pacing.
        self._last_chunk_time = time.time()

    def read(self):
        """Start processing buffered audio and emit fixed-size chunks."""
        # Runs the emit loop on a daemon thread, then blocks until close() is
        # called, finally giving the worker up to 2 s to finish.
        self._processing_thread = threading.Thread(target=self._process_chunks)
        self._processing_thread.daemon = True
        self._processing_thread.start()

        self._close_event.wait()
        if self._processing_thread:
            self._processing_thread.join(timeout=2.0)

    def _process_chunks(self):
        """Process audio from queue and emit fixed-size chunks at regular intervals."""
        while not self._closed:
            try:
                # Short timeout so the loop can also flush a stale partial
                # buffer via the Empty branch below.
                audio_chunk = self._queue.get(timeout=0.1)

                with self._buffer_lock:
                    self._buffer = np.concatenate([self._buffer, audio_chunk])

                    # Emit as many full blocks as the buffer now holds.
                    while len(self._buffer) >= self.block_size:
                        chunk = self._buffer[:self.block_size]
                        self._buffer = self._buffer[self.block_size:]

                        # Throttle so chunks go out roughly every block_duration.
                        # NOTE(review): this sleeps while holding _buffer_lock;
                        # harmless today (only this thread acquires the lock),
                        # but worth revisiting if another thread ever takes it.
                        current_time = time.time()
                        time_since_last = current_time - self._last_chunk_time
                        if time_since_last < self.block_duration:
                            time.sleep(self.block_duration - time_since_last)

                        # Reshape to (1, n_samples) — presumably (channels,
                        # samples) as expected downstream; confirm with diart.
                        chunk_reshaped = chunk.reshape(1, -1)
                        self.stream.on_next(chunk_reshaped)
                        self._last_chunk_time = time.time()

            except Empty:
                # No new audio: if a partial block has been waiting longer than
                # one block_duration, zero-pad it to full size and emit it.
                with self._buffer_lock:
                    if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
                        padded_chunk = np.zeros(self.block_size, dtype=np.float32)
                        padded_chunk[:len(self._buffer)] = self._buffer
                        self._buffer = np.array([], dtype=np.float32)

                        chunk_reshaped = padded_chunk.reshape(1, -1)
                        self.stream.on_next(chunk_reshaped)
                        self._last_chunk_time = time.time()
            except Exception as e:
                # Any other failure is surfaced to the stream and stops the loop.
                logger.error(f"Error in audio processing thread: {e}")
                self.stream.on_error(e)
                break

        # Shutdown: flush any remaining samples (zero-padded) before completing.
        with self._buffer_lock:
            if len(self._buffer) > 0:
                padded_chunk = np.zeros(self.block_size, dtype=np.float32)
                padded_chunk[:len(self._buffer)] = self._buffer
                chunk_reshaped = padded_chunk.reshape(1, -1)
                self.stream.on_next(chunk_reshaped)

        self.stream.on_completed()

    def close(self):
        """Stop the worker loop and unblock read()."""
        if not self._closed:
            self._closed = True
            self._close_event.set()

    def push_audio(self, chunk: np.ndarray):
        """Add audio chunk to the processing queue."""
        if not self._closed:
            # Flatten multi-dimensional input to a 1-D sample stream.
            if chunk.ndim > 1:
                chunk = chunk.flatten()
            self._queue.put(chunk)
            logger.debug(f'Added chunk to queue with {len(chunk)} samples')
|
class DiartDiarization:
    """Streaming speaker diarization wrapper around diart's StreamingInference.

    Feeds audio either from the microphone or from a WebSocketAudioSource and
    collects speaker segments through a DiarizationObserver.
    """

    def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
        if config is None:
            # Only load the default models when no explicit config is supplied.
            # (Previously they were loaded unconditionally and then discarded
            # whenever the caller passed a config — expensive wasted work.)
            segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
            embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
            config = SpeakerDiarizationConfig(
                segmentation=segmentation_model,
                embedding=embedding_model,
            )

        self.pipeline = SpeakerDiarization(config=config)
        self.observer = DiarizationObserver()

        if use_microphone:
            self.source = MicrophoneAudioSource(block_duration=block_duration)
            # No custom source: insert_audio_chunk() becomes a no-op.
            self.custom_source = None
        else:
            self.custom_source = WebSocketAudioSource(
                uri="websocket_source",
                sample_rate=sample_rate,
                block_duration=block_duration
            )
            self.source = self.custom_source

        self.inference = StreamingInference(
            pipeline=self.pipeline,
            source=self.source,
            do_plot=False,
            show_progress=False,
        )
        self.inference.attach_observers(self.observer)
        # Run the blocking inference loop on a background executor thread.
        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
        # running (3.10+); this assumes construction happens with a loop
        # available — confirm call sites before migrating to get_running_loop().
        asyncio.get_event_loop().run_in_executor(None, self.inference)

    def insert_silence(self, silence_duration):
        """Shift timestamps of future segments forward by silence_duration seconds."""
        self.observer.global_time_offset += silence_duration

    def insert_audio_chunk(self, pcm_array: np.ndarray):
        """Buffer audio for the next diarization step."""
        if self.custom_source:
            self.custom_source.push_audio(pcm_array)

    async def diarize(self):
        """Return the current speaker segments from the diarization pipeline."""
        return self.observer.get_segments()

    def close(self):
        """Close the audio source."""
        if self.custom_source:
            self.custom_source.close()
|
def concatenate_speakers(segments):
    """Collapse consecutive same-speaker diarization segments into single spans.

    Each input segment must expose ``.speaker`` (a label such as "speaker0"),
    ``.start`` and ``.end``. Returns a list of dicts with 1-based ``speaker``
    ids and ``begin``/``end`` times; the list always starts from a sentinel
    ``{"speaker": 1, "begin": 0.0, "end": 0.0}`` entry that the first speaker-1
    run extends in place.
    """
    segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
    for segment in segments:
        match = re.search(r'\d+', segment.speaker)
        # Labels are 0-based ("speaker0"); shift to 1-based. A label with no
        # digits previously crashed with TypeError (None + 1); fall back to
        # id 0 -> speaker 1 instead.
        speaker = (int(match.group()) if match else 0) + 1
        if segments_concatenated[-1]['speaker'] != speaker:
            segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
        else:
            # Same speaker as the previous run: just extend it.
            segments_concatenated[-1]['end'] = segment.end
    return segments_concatenated
|
def add_speaker_to_tokens(segments, tokens):
    """
    Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
    """
    punctuation_marks = {'.', '!', '?'}
    # Tokens whose text is exactly a sentence-final punctuation mark.
    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
    segments_concatenated = concatenate_speakers(segments)
    # Snap each speaker-change boundary to the nearest sentence-final
    # punctuation: for every segment, find the first punctuation token that
    # falls after the segment end and move the boundary to whichever
    # punctuation (previous or next) is closer.
    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            if punctuation_token.start > segment['end']:
                after_length = punctuation_token.start - segment['end']
                # NOTE(review): when i == 0 this indexes punctuation_tokens[-1]
                # (the LAST punctuation token) — looks like an off-by-one;
                # confirm intended behavior for the first punctuation mark.
                before_length = segment['end'] - punctuation_tokens[i - 1].end
                if before_length > after_length:
                    # Next punctuation is closer: extend this segment forward.
                    segment['end'] = punctuation_token.start
                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                else:
                    # Previous punctuation is closer: pull this segment back.
                    segment['end'] = punctuation_tokens[i - 1].end
                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                break

    # Enforce strictly increasing, non-overlapping token timestamps
    # (each token starts at least 0.01 s after the previous one ends).
    last_end = 0.0
    for token in tokens:
        start = max(last_end + 0.01, token.start)
        token.start = start
        token.end = max(start, token.end)
        last_end = token.end

    # Assign each token the speaker of the segment whose end covers it.
    ind_last_speaker = 0
    for segment in segments_concatenated:
        for i, token in enumerate(tokens[ind_last_speaker:]):
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                # NOTE(review): i is relative to the slice, so this resets the
                # scan position rather than advancing it monotonically across
                # segments — confirm whether `ind_last_speaker + i + 1` was meant.
                ind_last_speaker = i + 1
            elif token.start > segment['end']:
                break
    return tokens
|
def visualize_tokens(tokens):
    """Print tokens as a conversation, merging consecutive tokens per speaker.

    Each token must expose ``.speaker`` and ``.text``. The transcript starts
    from a sentinel turn (speaker -1, empty text) so the very first token
    always opens a new turn.
    """
    turns = [{"speaker": -1, "text": ""}]
    for token in tokens:
        if token.speaker == turns[-1]['speaker']:
            # Same speaker as the open turn: append the text.
            turns[-1]['text'] += token.text
        else:
            # Speaker changed: open a new turn.
            turns.append({"speaker": token.speaker, "text": token.text})
    print("Conversation:")
    for entry in turns:
        print(f"Speaker {entry['speaker']}: {entry['text']}")