diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 93468d1..0813496 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -9,7 +9,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware -from whisper_online import asr_factory, add_shared_args +from whisper_online import backend_factory, online_factory, add_shared_args app = FastAPI() app.add_middleware( @@ -40,7 +40,7 @@ parser.add_argument( add_shared_args(parser) args = parser.parse_args() -asr, online = asr_factory(args) +asr, tokenizer = backend_factory(args) # Load demo HTML for the root endpoint with open("src/live_transcription.html", "r") as f: @@ -85,6 +85,9 @@ async def websocket_endpoint(websocket: WebSocket): ffmpeg_process = await start_ffmpeg_decoder() pcm_buffer = bytearray() + print("Loading online.") + online = online_factory(args, asr, tokenizer) + print("Online loaded.") # Continuously read decoded PCM from ffmpeg stdout in a background task async def ffmpeg_stdout_reader(): diff --git a/whisper_online.py b/whisper_online.py index b65c36a..8fadec1 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -920,11 +920,7 @@ def add_shared_args(parser): default="DEBUG", ) - -def asr_factory(args, logfile=sys.stderr): - """ - Creates and configures an ASR and ASR Online instance based on the specified backend and arguments. - """ +def backend_factory(args): backend = args.backend if backend == "openai-api": logger.debug("Using OpenAI API.") @@ -967,10 +963,10 @@ def asr_factory(args, logfile=sys.stderr): tokenizer = create_tokenizer(tgt_language) else: tokenizer = None + return asr, tokenizer - # Create the OnlineASRProcessor +def online_factory(args, asr, tokenizer, logfile=sys.stderr): if args.vac: - online = VACOnlineASRProcessor( args.min_chunk_size, asr, @@ -985,10 +981,16 @@ def asr_factory(args, logfile=sys.stderr): logfile=logfile, buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), ) - + return online + +def asr_factory(args, logfile=sys.stderr): + """ + Creates and configures an ASR and ASR Online instance based on the specified backend and arguments. + """ + asr, tokenizer = backend_factory(args) + online = online_factory(args, asr, tokenizer, logfile=logfile) return asr, online - def set_logging(args, logger, other="_server"): logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s logger.setLevel(args.log_level)