From 5ca65e21b725e102958deb6ded37f6bb04b928d7 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 19 Mar 2025 10:33:22 +0100 Subject: [PATCH] Refactor DiartDiarization initialization and streamline WebSocket audio processing --- audio.py | 136 ++++++++++++++++++------------ diarization/diarization_online.py | 2 +- whisper_fastapi_online_server.py | 80 ++---------------- 3 files changed, 89 insertions(+), 129 deletions(-) diff --git a/audio.py b/audio.py index 1b97746..ee6ca56 100644 --- a/audio.py +++ b/audio.py @@ -1,25 +1,15 @@ -import io -import argparse import asyncio import numpy as np import ffmpeg from time import time, sleep -from contextlib import asynccontextmanager -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from fastapi.responses import HTMLResponse -from fastapi.middleware.cors import CORSMiddleware - -from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr -from timed_objects import ASRToken +from whisper_streaming_custom.whisper_online import online_factory import math import logging -from datetime import timedelta import traceback from state import SharedState from formatters import format_time -from parse_args import parse_args logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -27,7 +17,6 @@ logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - class AudioProcessor: def __init__(self, args, asr, tokenizer): @@ -38,9 +27,22 @@ class AudioProcessor: self.bytes_per_sample = 2 self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz + + self.shared_state = SharedState() self.asr = asr self.tokenizer = tokenizer + + self.ffmpeg_process = self.start_ffmpeg_decoder() + + self.transcription_queue = asyncio.Queue() if self.args.transcription else None + self.diarization_queue = asyncio.Queue() if 
self.args.diarization else None + + self.pcm_buffer = bytearray() + if self.args.transcription: + self.online = online_factory(self.args, self.asr, self.tokenizer) + + def convert_pcm_to_float(self, pcm_buffer): """ @@ -70,26 +72,17 @@ class AudioProcessor: ) return process - async def restart_ffmpeg(self, ffmpeg_process, online, pcm_buffer): - if ffmpeg_process: + async def restart_ffmpeg(self): + if self.ffmpeg_process: try: - ffmpeg_process.kill() - await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait) + self.ffmpeg_process.kill() + await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) except Exception as e: logger.warning(f"Error killing FFmpeg process: {e}") - ffmpeg_process = await self.start_ffmpeg_decoder() - pcm_buffer = bytearray() - - if self.args.transcription: - online = online_factory(self.args, self.asr, self.tokenizer) - - await self.shared_state.reset() - logger.info("FFmpeg process started.") - return ffmpeg_process, online, pcm_buffer + self.ffmpeg_process = await self.start_ffmpeg_decoder() + self.pcm_buffer = bytearray() - - - async def ffmpeg_stdout_reader(self, ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue): + async def ffmpeg_stdout_reader(self): loop = asyncio.get_event_loop() beg = time() @@ -103,36 +96,36 @@ class AudioProcessor: try: chunk = await asyncio.wait_for( loop.run_in_executor( - None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration + None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration ), timeout=15.0 ) except asyncio.TimeoutError: logger.warning("FFmpeg read timeout. 
Restarting...") - ffmpeg_process, online, pcm_buffer = await self.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) + await self.restart_ffmpeg() beg = time() continue # Skip processing and read from new process if not chunk: logger.info("FFmpeg stdout closed.") break - pcm_buffer.extend(chunk) + self.pcm_buffer.extend(chunk) - if self.args.diarization and diarization_queue: - await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy()) + if self.args.diarization and self.diarization_queue: + await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy()) - if len(pcm_buffer) >= self.bytes_per_sec: - if len(pcm_buffer) > self.max_bytes_per_sec: + if len(self.pcm_buffer) >= self.bytes_per_sec: + if len(self.pcm_buffer) > self.max_bytes_per_sec: logger.warning( - f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds. + f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds. The model probably struggles to keep up. Consider using a smaller model. 
""") - pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec]) - pcm_buffer = pcm_buffer[self.max_bytes_per_sec:] + pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec]) + self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:] - if self.args.transcription and transcription_queue: - await transcription_queue.put(pcm_array.copy()) + if self.args.transcription and self.transcription_queue: + await self.transcription_queue.put(pcm_array.copy()) if not self.args.transcription and not self.args.diarization: @@ -144,27 +137,24 @@ class AudioProcessor: break logger.info("Exiting ffmpeg_stdout_reader...") - - - - async def transcription_processor(self, pcm_queue, online): + async def transcription_processor(self): full_transcription = "" - sep = online.asr.sep + sep = self.online.asr.sep while True: try: - pcm_array = await pcm_queue.get() + pcm_array = await self.transcription_queue.get() - logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") + logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.") # Process transcription - online.insert_audio_chunk(pcm_array) - new_tokens = online.process_iter() + self.online.insert_audio_chunk(pcm_array) + new_tokens = self.online.process_iter() if new_tokens: full_transcription += sep.join([t.text for t in new_tokens]) - _buffer = online.get_buffer() + _buffer = self.online.get_buffer() buffer = _buffer.text end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0) @@ -178,14 +168,15 @@ class AudioProcessor: logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") finally: - pcm_queue.task_done() + self.transcription_queue.task_done() - async def diarization_processor(self, pcm_queue, diarization_obj): + + async def diarization_processor(self, diarization_obj): 
buffer_diarization = "" while True: try: - pcm_array = await pcm_queue.get() + pcm_array = await self.diarization_queue.get() # Process diarization await diarization_obj.diarize(pcm_array) @@ -205,7 +196,7 @@ class AudioProcessor: logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") finally: - pcm_queue.task_done() + self.diarization_queue.task_done() async def results_formatter(self, websocket): while True: @@ -304,3 +295,40 @@ class AudioProcessor: logger.warning(f"Exception in results_formatter: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") await asyncio.sleep(0.5) # Back off on error + + async def create_tasks(self, websocket, diarization): + tasks = [] + if self.args.transcription and self.online: + tasks.append(asyncio.create_task(self.transcription_processor())) + if self.args.diarization and diarization: + tasks.append(asyncio.create_task(self.diarization_processor(diarization))) + formatter_task = asyncio.create_task(self.results_formatter(websocket)) + tasks.append(formatter_task) + stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) + tasks.append(stdout_reader_task) + self.tasks = tasks + self.diarization = diarization + + async def cleanup(self): + for task in self.tasks: + task.cancel() + try: + await asyncio.gather(*self.tasks, return_exceptions=True) + self.ffmpeg_process.stdin.close() + self.ffmpeg_process.wait() + except Exception as e: + logger.warning(f"Error during cleanup: {e}") + if self.args.diarization and self.diarization: + self.diarization.close() + + async def process_audio(self, message): + try: + self.ffmpeg_process.stdin.write(message) + self.ffmpeg_process.stdin.flush() + except (BrokenPipeError, AttributeError) as e: + logger.warning(f"Error writing to FFmpeg: {e}. 
Restarting...") + await self.restart_ffmpeg() + self.ffmpeg_process.stdin.write(message) + self.ffmpeg_process.stdin.flush() + + \ No newline at end of file diff --git a/diarization/diarization_online.py b/diarization/diarization_online.py index 45bec13..622fb15 100644 --- a/diarization/diarization_online.py +++ b/diarization/diarization_online.py @@ -103,7 +103,7 @@ class WebSocketAudioSource(AudioSource): class DiartDiarization: - def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False): + def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False): self.pipeline = SpeakerDiarization(config=config) self.observer = DiarizationObserver() diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 99eb4d4..b0ca658 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -1,24 +1,11 @@ -import io -import argparse -import asyncio -import numpy as np -import ffmpeg -from time import time, sleep from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware -from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr -from timed_objects import ASRToken - -import math +from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr import logging -from datetime import timedelta -import traceback -from state import SharedState -from formatters import format_time from parse_args import parse_args from audio import AudioProcessor @@ -27,19 +14,8 @@ logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - - args = parse_args() -SAMPLE_RATE = 16000 -# CHANNELS = 1 -# SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size) -# BYTES_PER_SAMPLE = 2 # s16le 
= 2 bytes per sample -# BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE -# MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz - - -##### LOAD APP ##### @asynccontextmanager async def lifespan(app: FastAPI): @@ -52,7 +28,7 @@ async def lifespan(app: FastAPI): if args.diarization: from diarization.diarization_online import DiartDiarization - diarization = DiartDiarization(SAMPLE_RATE) + diarization = DiartDiarization() else : diarization = None yield @@ -75,66 +51,22 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f: async def get(): return HTMLResponse(html) - - - - - - - @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket): audio_processor = AudioProcessor(args, asr, tokenizer) await websocket.accept() logger.info("WebSocket connection opened.") - - ffmpeg_process = None - pcm_buffer = bytearray() - - transcription_queue = asyncio.Queue() if args.transcription else None - diarization_queue = asyncio.Queue() if args.diarization else None - - online = None - - ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) - tasks = [] - if args.transcription and online: - tasks.append(asyncio.create_task( - audio_processor.transcription_processor(transcription_queue, online))) - if args.diarization and diarization: - tasks.append(asyncio.create_task( - audio_processor.diarization_processor(diarization_queue, diarization))) - formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket)) - tasks.append(formatter_task) - stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue)) - tasks.append(stdout_reader_task) - + + await audio_processor.create_tasks(websocket, diarization) try: while True: - # Receive incoming WebM audio chunks from the client message = await websocket.receive_bytes() - try: - ffmpeg_process.stdin.write(message) - 
ffmpeg_process.stdin.flush() - except (BrokenPipeError, AttributeError) as e: - logger.warning(f"Error writing to FFmpeg: {e}. Restarting...") - ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) - ffmpeg_process.stdin.write(message) - ffmpeg_process.stdin.flush() + await audio_processor.process_audio(message) except WebSocketDisconnect: logger.warning("WebSocket disconnected.") finally: - for task in tasks: - task.cancel() - try: - await asyncio.gather(*tasks, return_exceptions=True) - ffmpeg_process.stdin.close() - ffmpeg_process.wait() - except Exception as e: - logger.warning(f"Error during cleanup: {e}") - if args.diarization and diarization: - diarization.close() + await audio_processor.cleanup() logger.info("WebSocket endpoint cleaned up.") if __name__ == "__main__":