NOTE(review): line breaks, indentation and blank context lines in this patch
were reconstructed from the hunk line counts of a whitespace-collapsed copy;
confirm the whitespace against the working tree before applying.
The websocket_endpoint hunk also awaits AudioProcessor.cleanup(), which
audio.py declares as "async def cleanup(self)"; without the await the
coroutine is never run and the processor's tasks are not cancelled.

diff --git a/audio.py b/audio.py
index ee6ca56..2436685 100644
--- a/audio.py
+++ b/audio.py
@@ -54,7 +54,7 @@ class AudioProcessor:
                     / 32768.0)
         return pcm_array
 
-    async def start_ffmpeg_decoder(self):
+    def start_ffmpeg_decoder(self):
         """
         Start an FFmpeg process in async streaming mode that reads WebM from stdin
         and outputs raw s16le PCM on stdout. Returns the process object.
@@ -79,7 +79,7 @@ class AudioProcessor:
                 await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
             except Exception as e:
                 logger.warning(f"Error killing FFmpeg process: {e}")
-        self.ffmpeg_process = await self.start_ffmpeg_decoder()
+        self.ffmpeg_process = self.start_ffmpeg_decoder()
         self.pcm_buffer = bytearray()
 
     async def ffmpeg_stdout_reader(self):
@@ -198,10 +198,9 @@ class AudioProcessor:
             finally:
                 self.diarization_queue.task_done()
 
-    async def results_formatter(self, websocket):
+    async def results_formatter(self):
         while True:
             try:
-                # Get the current state
                 state = await self.shared_state.get_current_state()
                 tokens = state["tokens"]
                 buffer_transcription = state["buffer_transcription"]
@@ -217,7 +216,6 @@ class AudioProcessor:
                     sleep(0.5)
                     state = await self.shared_state.get_current_state()
                     tokens = state["tokens"]
-                # Process tokens to create response
                 previous_speaker = -1
                 lines = []
                 last_end_diarized = 0
@@ -273,22 +271,21 @@ class AudioProcessor:
                             "beg": format_time(0),
                             "end": format_time(tokens[-1].end) if tokens else format_time(0),
                             "diff": 0
-                    }],
+                        }],
                         "buffer_transcription": buffer_transcription,
                         "buffer_diarization": buffer_diarization,
                         "remaining_time_transcription": remaining_time_transcription,
                         "remaining_time_diarization": remaining_time_diarization
-
                     }
 
                 response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization
 
                 if response_content != self.shared_state.last_response_content:
                     if lines or buffer_transcription or buffer_diarization:
-                        await websocket.send_json(response)
+                        yield response
                         self.shared_state.last_response_content = response_content
 
-                # Add a small delay to avoid overwhelming the client
+                # Small delay to avoid overwhelming the client
                 await asyncio.sleep(0.1)
 
             except Exception as e:
@@ -296,18 +293,22 @@ class AudioProcessor:
                 logger.warning(f"Traceback: {traceback.format_exc()}")
                 await asyncio.sleep(0.5)  # Back off on error
 
-    async def create_tasks(self, websocket, diarization):
+    async def create_tasks(self, diarization=None):
+        if diarization:
+            self.diarization = diarization
+
         tasks = []
         if self.args.transcription and self.online:
             tasks.append(asyncio.create_task(self.transcription_processor()))
-        if self.args.diarization and diarization:
-            tasks.append(asyncio.create_task(self.diarization_processor(diarization)))
-        formatter_task = asyncio.create_task(self.results_formatter(websocket))
-        tasks.append(formatter_task)
+        if self.args.diarization and self.diarization:
+            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
+
         stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
         tasks.append(stdout_reader_task)
+
         self.tasks = tasks
-        self.diarization = diarization
+
+        return self.results_formatter()
 
     async def cleanup(self):
         for task in self.tasks:
diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py
index b0ca658..cf81ff8 100644
--- a/whisper_fastapi_online_server.py
+++ b/whisper_fastapi_online_server.py
@@ -5,6 +5,7 @@ from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 
 from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
+import asyncio
 import logging
 from parse_args import parse_args
 from audio import AudioProcessor
@@ -51,6 +52,16 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f:
 async def get():
     return HTMLResponse(html)
 
+
+async def handle_websocket_results(websocket, results_generator):
+    """Consumes results from the audio processor and sends them via WebSocket."""
+    try:
+        async for response in results_generator:
+            await websocket.send_json(response)
+    except Exception as e:
+        logger.warning(f"Error in WebSocket results handler: {e}")
+
+
 @app.websocket("/asr")
 async def websocket_endpoint(websocket: WebSocket):
     audio_processor = AudioProcessor(args, asr, tokenizer)
@@ -58,14 +69,17 @@ async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     logger.info("WebSocket connection opened.")
 
-    await audio_processor.create_tasks(websocket, diarization)
+    results_generator = await audio_processor.create_tasks(diarization)
+    websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
+
     try:
         while True:
             message = await websocket.receive_bytes()
-            audio_processor.process_audio(message)
+            await audio_processor.process_audio(message)
     except WebSocketDisconnect:
         logger.warning("WebSocket disconnected.")
     finally:
-        audio_processor.cleanup()
+        websocket_task.cancel()
+        await audio_processor.cleanup()
         logger.info("WebSocket endpoint cleaned up.")
 