From 2ced4fef2040bb157847b7a9081e9f68c048d0ab Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Mon, 24 Feb 2025 00:35:42 +0100 Subject: [PATCH] diarization now works at word - not chunk - level! --- src/diarization/diarization_online.py | 11 ++--- src/whisper_streaming/online_asr.py | 2 +- src/whisper_streaming/timed_objects.py | 3 +- whisper_fastapi_online_server.py | 64 +++++++++++--------------- 4 files changed, 36 insertions(+), 44 deletions(-) diff --git a/src/diarization/diarization_online.py b/src/diarization/diarization_online.py index 1fcc436..88b3539 100644 --- a/src/diarization/diarization_online.py +++ b/src/diarization/diarization_online.py @@ -81,11 +81,10 @@ class DiartDiarization: def close(self): self.source.close() - def assign_speakers_to_chunks(self, chunks: list) -> list: - end_attributed_speaker = 0 - for chunk in chunks: + def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list: + for token in tokens: for segment in self.segment_speakers: - if not (segment["end"] <= chunk["beg"] or segment["beg"] >= chunk["end"]): - chunk["speaker"] = extract_number(segment["speaker"]) + 1 - end_attributed_speaker = chunk["end"] + if not (segment["end"] <= token.start or segment["beg"] >= token.end): + token.speaker = extract_number(segment["speaker"]) + 1 + end_attributed_speaker = max(token.end, end_attributed_speaker) return end_attributed_speaker diff --git a/src/whisper_streaming/online_asr.py b/src/whisper_streaming/online_asr.py index 2aebe99..6e1e4db 100644 --- a/src/whisper_streaming/online_asr.py +++ b/src/whisper_streaming/online_asr.py @@ -202,7 +202,7 @@ class OnlineASRProcessor: logger.debug( f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds" ) - return self.concatenate_tokens(committed_tokens) + return committed_tokens def chunk_completed_sentence(self): """ diff --git a/src/whisper_streaming/timed_objects.py b/src/whisper_streaming/timed_objects.py index 19f3d4f..d6f7b36 100644 --- a/src/whisper_streaming/timed_objects.py +++ b/src/whisper_streaming/timed_objects.py @@ -5,7 +5,8 @@ from typing import Optional class TimedText: start: Optional[float] end: Optional[float] - text: str + text: Optional[str] = '' + speaker: Optional[int] = -1 @dataclass class ASRToken(TimedText): diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index ce47e7e..94914a8 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -11,6 +11,7 @@ from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args +from src.whisper_streaming.timed_objects import ASRToken import math import logging @@ -47,7 +48,7 @@ parser.add_argument( parser.add_argument( "--diarization", type=bool, - default=False, + default=True, help="Whether to enable speaker diarization.", ) @@ -157,7 +158,9 @@ async def websocket_endpoint(websocket: WebSocket): full_transcription = "" beg = time() beg_loop = time() - chunk_history = [] # Will store dicts: {beg, end, text, speaker} + tokens = [] + end_attributed_speaker = 0 + sep = online.asr.sep while True: try: @@ -177,7 +180,6 @@ async def websocket_endpoint(websocket: WebSocket): logger.warning("FFmpeg read timeout. Restarting...") await restart_ffmpeg() full_transcription = "" - chunk_history = [] beg = time() continue # Skip processing and read from new process @@ -202,63 +204,53 @@ async def websocket_endpoint(websocket: WebSocket): if args.transcription: logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") online.insert_audio_chunk(pcm_array) - transcription = online.process_iter() - if transcription.start: - chunk_history.append({ - "beg": transcription.start, - "end": transcription.end, - "text": transcription.text, - "speaker": -1 - }) - full_transcription += transcription.text if transcription else "" + new_tokens = online.process_iter() + tokens.extend(new_tokens) + full_transcription += sep.join([t.text for t in new_tokens]) buffer = online.get_buffer() if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed buffer = "" else: - chunk_history.append({ - "beg": time() - beg_loop, - "end": time() - beg_loop + 1, - "text": '', - "speaker": -1 - }) - sleep(1) + tokens.append( + ASRToken( + start = time() - beg_loop, + end = time() - beg_loop + 0.5)) + sleep(0.5) buffer = '' if args.diarization: await diarization.diarize(pcm_array) - end_attributed_speaker = diarization.assign_speakers_to_chunks(chunk_history) - + end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens) - current_speaker = -10 + previous_speaker = -10 lines = [] last_end_diarized = 0 - previous_speaker = -1 - for ind, ch in enumerate(chunk_history): - speaker = ch.get("speaker") + for token in tokens: + speaker = token.speaker if args.diarization: if speaker == -1 or speaker == 0: - if ch['end'] < end_attributed_speaker: + if token.end < end_attributed_speaker: speaker = previous_speaker else: speaker = 0 else: - last_end_diarized = max(ch['end'], last_end_diarized) + last_end_diarized = max(token.end, last_end_diarized) - if speaker != current_speaker: + if speaker != previous_speaker: lines.append( { "speaker": speaker, - "text": ch['text'], - "beg": format_time(ch['beg']), - "end": format_time(ch['end']), - "diff": round(ch['end'] - last_end_diarized, 2) + "text": token.text, + "beg": format_time(token.start), + "end": format_time(token.end), + "diff": round(token.end - last_end_diarized, 2) } ) - current_speaker = speaker + previous_speaker = speaker else: - lines[-1]["text"] += ch['text'] - lines[-1]["end"] = format_time(ch['end']) - lines[-1]["diff"] = round(ch['end'] - last_end_diarized, 2) + lines[-1]["text"] += sep + token.text + lines[-1]["end"] = format_time(token.end) + lines[-1]["diff"] = round(token.end - last_end_diarized, 2) response = {"lines": lines, "buffer": buffer} await websocket.send_json(response)