diff --git a/README.md b/README.md index 9c00570..8c7b26d 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ This project extends the [Whisper Streaming](https://github.com/ufal/whisper_str 5. **MLX Whisper backend**: Integrates the alternative backend option MLX Whisper, optimized for efficient speech recognition on Apple silicon. +6. **Diarization (beta)**: Adds speaker labeling in real-time alongside transcription using the [Diart](https://github.com/juanmc2005/diart) library. Each transcription segment is tagged with a speaker. +  ## Code Origins @@ -64,6 +66,9 @@ This project reuses and extends code from the original Whisper Streaming reposit # If you want to run the server using uvicorn (recommended) uvicorn + + # If you want to use diarization + diart ``` @@ -76,6 +81,8 @@ This project reuses and extends code from the original Whisper Streaming reposit - `--host` and `--port` let you specify the server’s IP/port. - `-min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data. - For a full list of configurable options, run `python whisper_fastapi_online_server.py -h` + - `--diarization`, defaulting to False, lets you choose whether or not you want to run diarization in parallel + - For other parameters, look at [whisper streaming](https://github.com/ufal/whisper_streaming) readme. 4. 
**Open the Provided HTML**: diff --git a/src/diarization/diarization_online.py b/src/diarization/diarization_online.py new file mode 100644 index 0000000..432e104 --- /dev/null +++ b/src/diarization/diarization_online.py @@ -0,0 +1,110 @@ +from diart import SpeakerDiarization +from diart.inference import StreamingInference +from diart.sources import AudioSource +from rx.subject import Subject +import threading +import numpy as np +import asyncio + +class WebSocketAudioSource(AudioSource): + """ + Simple custom AudioSource that blocks in read() + until close() is called. + push_audio() is used to inject new PCM chunks. + """ + def __init__(self, uri: str = "websocket", sample_rate: int = 16000): + super().__init__(uri, sample_rate) + self._close_event = threading.Event() + self._closed = False + + def read(self): + self._close_event.wait() + + def close(self): + if not self._closed: + self._closed = True + self.stream.on_completed() + self._close_event.set() + + def push_audio(self, chunk: np.ndarray): + chunk = np.expand_dims(chunk, axis=0) + if not self._closed: + self.stream.on_next(chunk) + + +def create_pipeline(SAMPLE_RATE): + diar_pipeline = SpeakerDiarization() + ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE) + inference = StreamingInference( + pipeline=diar_pipeline, + source=ws_source, + do_plot=False, + show_progress=False, + ) + return inference, ws_source + + +def init_diart(SAMPLE_RATE): + inference, ws_source = create_pipeline(SAMPLE_RATE) + + def diar_hook(result): + """ + Hook called each time Diart processes a chunk. + result is (annotation, audio). + We store the label of the last segment in 'current_speaker'. 
+ """ + global l_speakers + l_speakers = [] + annotation, audio = result + for speaker in annotation._labels: + segments_beg = annotation._labels[speaker].segments_boundaries_[0] + segments_end = annotation._labels[speaker].segments_boundaries_[-1] + asyncio.create_task( + l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end}) + ) + + l_speakers_queue = asyncio.Queue() + inference.attach_hooks(diar_hook) + + # Launch Diart in a background thread + loop = asyncio.get_event_loop() + diar_future = loop.run_in_executor(None, inference) + return inference, l_speakers_queue, ws_source + + +class DiartDiarization(): + def __init__(self, SAMPLE_RATE): + self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE) + self.segment_speakers = [] + + async def diarize(self, pcm_array): + self.ws_source.push_audio(pcm_array) + self.segment_speakers = [] + while not self.l_speakers_queue.empty(): + self.segment_speakers.append(await self.l_speakers_queue.get()) + + def close(self): + self.ws_source.close() + + + def assign_speakers_to_chunks(self, chunks): + """ + Go through each chunk and see which speaker(s) overlap + that chunk's time range in the Diart annotation. + Then store the speaker label(s) (or choose the most overlapping). + This modifies `chunks` in-place or returns a new list with assigned speakers. + """ + if not self.segment_speakers: + return chunks + + for segment in self.segment_speakers: + seg_beg = segment["beg"] + seg_end = segment["end"] + speaker = segment["speaker"] + for ch in chunks: + if seg_end <= ch["beg"] or seg_beg >= ch["end"]: + continue + # We have overlap. 
Let's just pick the speaker (could be more precise in a more complex implementation) + ch["speaker"] = speaker + + return chunks diff --git a/src/web/demo.png b/src/web/demo.png index 8add3fc..53522be 100644 Binary files a/src/web/demo.png and b/src/web/demo.png differ diff --git a/src/web/live_transcription.html b/src/web/live_transcription.html index 8b8aa10..0c6d0ba 100644 --- a/src/web/live_transcription.html +++ b/src/web/live_transcription.html @@ -7,8 +7,8 @@
+