def extract_number(s: str) -> "int | None":
    """Return the first run of digits in *s* as an int.

    Used to turn diart speaker labels such as ``"speaker0"`` into a
    numeric id. Returns ``None`` when *s* contains no digits (the
    original annotation claimed ``int``, but the ``else`` branch
    returns ``None``).
    """
    m = re.search(r'\d+', s)
    return int(m.group()) if m else None
+ """ + def __init__(self, uri: str = "websocket", sample_rate: int = 16000): + super().__init__(uri, sample_rate) + self._closed = False + self._close_event = threading.Event() + + def read(self): + self._close_event.wait() + + def close(self): + if not self._closed: + self._closed = True + self.stream.on_completed() + self._close_event.set() + + def push_audio(self, chunk: np.ndarray): + if not self._closed: + self.stream.on_next(np.expand_dims(chunk, axis=0)) + + +class DiartDiarization: + def __init__(self, sample_rate: int): + self.processed_time = 0 + self.segment_speakers = [] + self.speakers_queue = asyncio.Queue() + self.pipeline = SpeakerDiarization() + self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate) + self.inference = StreamingInference( + pipeline=self.pipeline, + source=self.source, + do_plot=False, + show_progress=False, + ) + # Attache la fonction hook et démarre l'inférence en arrière-plan. + self.inference.attach_hooks(self._diar_hook) + asyncio.get_event_loop().run_in_executor(None, self.inference) + + def _diar_hook(self, result): + annotation, audio = result + if annotation._labels: + for speaker, label in annotation._labels.items(): + start = label.segments_boundaries_[0] + end = label.segments_boundaries_[-1] + if end > self.processed_time: + self.processed_time = end + asyncio.create_task(self.speakers_queue.put(SpeakerSegment( + speaker=speaker, + start=start, + end=end, + ))) + else: + dur = audio.extent.end + if dur > self.processed_time: + self.processed_time = dur + + async def diarize(self, pcm_array: np.ndarray): + self.source.push_audio(pcm_array) + self.segment_speakers.clear() + while not self.speakers_queue.empty(): + self.segment_speakers.append(await self.speakers_queue.get()) + + def close(self): + self.source.close() + + def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list: + for token in tokens: + for segment in self.segment_speakers: + if not (segment.end <= 
token.start or segment.start >= token.end): + token.speaker = extract_number(segment.speaker) + 1 + end_attributed_speaker = max(token.end, end_attributed_speaker) + return end_attributed_speaker diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index a58146f..937eb87 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -10,8 +10,8 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware -from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args -from src.whisper_streaming.timed_objects import ASRToken +from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args +from timed_objects import ASRToken import math import logging @@ -49,7 +49,7 @@ parser.add_argument( parser.add_argument( "--diarization", type=bool, - default=False, + default=True, help="Whether to enable speaker diarization.", ) @@ -157,7 +157,7 @@ async def lifespan(app: FastAPI): asr, tokenizer = None, None if args.diarization: - from src.diarization.diarization_online import DiartDiarization + from diarization.diarization_online import DiartDiarization diarization = DiartDiarization(SAMPLE_RATE) else : diarization = None @@ -174,7 +174,7 @@ app.add_middleware( # Load demo HTML for the root endpoint -with open("src/web/live_transcription.html", "r", encoding="utf-8") as f: +with open("web/live_transcription.html", "r", encoding="utf-8") as f: html = f.read() async def start_ffmpeg_decoder(): @@ -277,24 +277,18 @@ async def results_formatter(shared_state, websocket): # Process tokens to create response previous_speaker = -10 - lines = [ - ] + lines = [] last_end_diarized = 0 undiarized_text = [] for token in tokens: speaker = token.speaker - # Handle diarization differently if diarization is enabled if args.diarization: - # If token is not yet processed by 
diarization if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: - # Add this token's text to undiarized buffer instead of creating a new line undiarized_text.append(token.text) continue - # If speaker isn't assigned yet but should be (based on timestamp) elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: speaker = previous_speaker - # Track last diarized token end time if speaker not in [-1, 0]: last_end_diarized = max(token.end, last_end_diarized) @@ -314,7 +308,6 @@ async def results_formatter(shared_state, websocket): lines[-1]["end"] = format_time(token.end) lines[-1]["diff"] = round(token.end - last_end_diarized, 2) - # Update buffer_diarization with undiarized text if undiarized_text: combined_buffer_diarization = sep.join(undiarized_text) if buffer_transcription: @@ -322,7 +315,6 @@ async def results_formatter(shared_state, websocket): await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization) buffer_diarization = combined_buffer_diarization - # Prepare response object if lines: response = { "lines": lines, @@ -350,7 +342,6 @@ async def results_formatter(shared_state, websocket): response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization if response_content != shared_state.last_response_content: - # Only send if there's actual content to send if lines or buffer_transcription or buffer_diarization: await websocket.send_json(response) shared_state.last_response_content = response_content