From 149d2ee44c3897225f9a73b42d412faf45693739 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 12 Feb 2025 05:53:55 +0100 Subject: [PATCH] Use lifespan to load the model just one --- whisper_fastapi_online_server.py | 56 ++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 3490ac3..0e9c6b4 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -4,6 +4,7 @@ import asyncio import numpy as np import ffmpeg from time import time +from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse @@ -11,15 +12,8 @@ from fastapi.middleware.cors import CORSMiddleware from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +##### LOAD ARGS ##### parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server") parser.add_argument( @@ -49,28 +43,37 @@ parser.add_argument( add_shared_args(parser) args = parser.parse_args() -asr, tokenizer = backend_factory(args) - -if args.diarization: - from src.diarization.diarization_online import DiartDiarization - - -# Load demo HTML for the root endpoint -with open("src/web/live_transcription.html", "r", encoding="utf-8") as f: - html = f.read() - - -@app.get("/") -async def get(): - return HTMLResponse(html) - - SAMPLE_RATE = 16000 CHANNELS = 1 SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size) BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE +if args.diarization: + from src.diarization.diarization_online import DiartDiarization + + +##### LOAD APP ##### + +@asynccontextmanager +async def lifespan(app: FastAPI): + global asr, tokenizer + asr, tokenizer = backend_factory(args) + yield + +app = FastAPI(lifespan=lifespan) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Load demo HTML for the root endpoint +with open("src/web/live_transcription.html", "r", encoding="utf-8") as f: + html = f.read() async def start_ffmpeg_decoder(): """ @@ -91,6 +94,11 @@ async def start_ffmpeg_decoder(): return process +##### ENDPOINTS ##### + +@app.get("/") +async def get(): + return HTMLResponse(html) @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket):