diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 1c00751..4094d33 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -68,19 +68,23 @@ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz -if args.diarization: - from src.diarization.diarization_online import DiartDiarization ##### LOAD APP ##### @asynccontextmanager async def lifespan(app: FastAPI): - global asr, tokenizer + global asr, tokenizer, diarization if args.transcription: asr, tokenizer = backend_factory(args) else: asr, tokenizer = None, None + + if args.diarization: + from src.diarization.diarization_online import DiartDiarization + diarization = DiartDiarization(SAMPLE_RATE) + else : + diarization = None yield app = FastAPI(lifespan=lifespan) @@ -130,10 +134,10 @@ async def websocket_endpoint(websocket: WebSocket): ffmpeg_process = None pcm_buffer = bytearray() online = online_factory(args, asr, tokenizer) if args.transcription else None - diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None + async def restart_ffmpeg(): - nonlocal ffmpeg_process, online, diarization, pcm_buffer + nonlocal ffmpeg_process, online, pcm_buffer if ffmpeg_process: try: ffmpeg_process.kill() @@ -143,14 +147,12 @@ async def websocket_endpoint(websocket: WebSocket): ffmpeg_process = await start_ffmpeg_decoder() pcm_buffer = bytearray() online = online_factory(args, asr, tokenizer) if args.transcription else None - if args.diarization: - diarization = DiartDiarization(SAMPLE_RATE) logger.info("FFmpeg process started.") await restart_ffmpeg() async def ffmpeg_stdout_reader(): - nonlocal ffmpeg_process, online, diarization, pcm_buffer + nonlocal ffmpeg_process, online, pcm_buffer loop = asyncio.get_event_loop() full_transcription = "" beg = time()