From 14af47e84b5ff7f7fc865e5bd666379b5be16346 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 28 Feb 2025 18:11:36 +0100 Subject: [PATCH] undiarized text is assigned to last speaker, with buffer information; traceback is used to format_exc errors --- whisper_fastapi_online_server.py | 80 +++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 18 deletions(-) diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index f931775..a58146f 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -16,6 +16,7 @@ from src.whisper_streaming.timed_objects import ASRToken import math import logging from datetime import timedelta +import traceback def format_time(seconds): return str(timedelta(seconds=int(seconds))) @@ -48,7 +49,7 @@ parser.add_argument( parser.add_argument( "--diarization", type=bool, - default=True, + default=False, help="Whether to enable speaker diarization.", ) @@ -81,6 +82,7 @@ class SharedState: self.lock = asyncio.Lock() self.beg_loop = time() self.sep = " " # Default separator + self.last_response_content = "" # To track changes in response async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep): async with self.lock: @@ -119,7 +121,7 @@ class SharedState: # Calculate remaining time for diarization if self.end_attributed_speaker > 0: - remaining_time_diarization = max(0, round(current_time - self.beg_loop - self.end_attributed_speaker, 2)) + remaining_time_diarization = max(0, round(max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) - self.end_attributed_speaker, 2)) return { "tokens": self.tokens.copy(), @@ -142,6 +144,7 @@ class SharedState: self.end_attributed_speaker = 0 self.full_transcription = "" self.beg_loop = time() + self.last_response_content = "" ##### LOAD APP ##### @@ -221,6 +224,7 @@ async def transcription_processor(shared_state, pcm_queue, online): except Exception as e: logger.warning(f"Exception in transcription_processor: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") finally: pcm_queue.task_done() @@ -247,6 +251,7 @@ async def diarization_processor(shared_state, pcm_queue, diarization_obj): except Exception as e: logger.warning(f"Exception in diarization_processor: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") finally: pcm_queue.task_done() @@ -272,21 +277,28 @@ async def results_formatter(shared_state, websocket): # Process tokens to create response previous_speaker = -10 - lines = [] + lines = [ + ] last_end_diarized = 0 + undiarized_text = [] for token in tokens: speaker = token.speaker + # Handle diarization differently if diarization is enabled if args.diarization: - if speaker == -1 or speaker == 0: - if token.end < end_attributed_speaker: - speaker = previous_speaker - else: - speaker = 0 - else: + # If token is not yet processed by diarization + if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: + # Add this token's text to undiarized buffer instead of creating a new line + undiarized_text.append(token.text) + continue + # If speaker isn't assigned yet but should be (based on timestamp) + elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: + speaker = previous_speaker + # Track last diarized token end time + if speaker not in [-1, 0]: last_end_diarized = max(token.end, last_end_diarized) - if speaker != previous_speaker: + if speaker != previous_speaker or not lines: lines.append( { "speaker": speaker, @@ -302,22 +314,53 @@ async def results_formatter(shared_state, websocket): lines[-1]["end"] = format_time(token.end) lines[-1]["diff"] = round(token.end - last_end_diarized, 2) + # Update buffer_diarization with undiarized text + if undiarized_text: + combined_buffer_diarization = sep.join(undiarized_text) + if buffer_transcription: + combined_buffer_diarization += sep + await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization) + buffer_diarization = combined_buffer_diarization + # Prepare response object - response = { - "lines": lines, - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } + if lines: + response = { + "lines": lines, + "buffer_transcription": buffer_transcription, + "buffer_diarization": buffer_diarization, + "remaining_time_transcription": remaining_time_transcription, + "remaining_time_diarization": remaining_time_diarization + } + else: + response = { + "lines": [{ + "speaker": 1, + "text": "", + "beg": format_time(0), + "end": format_time(token.end) if token else format_time(0), + "diff": 0 + }], + "buffer_transcription": buffer_transcription, + "buffer_diarization": buffer_diarization, + "remaining_time_transcription": remaining_time_transcription, + "remaining_time_diarization": remaining_time_diarization + + } - await websocket.send_json(response) + response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization + + if response_content != shared_state.last_response_content: + # Only send if there's actual content to send + if lines or buffer_transcription or buffer_diarization: + await websocket.send_json(response) + shared_state.last_response_content = response_content # Add a small delay to avoid overwhelming the client await asyncio.sleep(0.1) except Exception as e: logger.warning(f"Exception in results_formatter: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") await asyncio.sleep(0.5) # Back off on error ##### ENDPOINTS ##### @@ -422,6 +465,7 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") break logger.info("Exiting ffmpeg_stdout_reader...")