mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
126 lines
4.2 KiB
Python
126 lines
4.2 KiB
Python
from contextlib import asynccontextmanager
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import HTMLResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
|
import asyncio
|
|
import logging
|
|
from starlette.staticfiles import StaticFiles
|
|
import pathlib
|
|
import whisperlivekit.web as webpkg
|
|
|
|
# Install a timestamped format on the root logger, then raise the root level to
# WARNING so loggers without their own level stay quiet; this module's logger
# is explicitly opened up to DEBUG.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logging.getLogger().setLevel(logging.WARNING)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Command-line options are parsed once, at import time; lifespan() and main()
# both read this namespace.
args = parse_args()

# Shared engine instance; remains None until lifespan() runs at app startup.
transcription_engine = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook that builds the shared transcription engine.

    On startup, the module-level ``transcription_engine`` is replaced with a
    ``TranscriptionEngine`` configured from every parsed CLI option. Nothing
    runs after ``yield``, so no explicit shutdown work is done here.
    """
    global transcription_engine
    engine_options = vars(args)
    transcription_engine = TranscriptionEngine(**engine_options)
    yield
|
|
|
|
app = FastAPI(lifespan=lifespan)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive; FastAPI's CORS docs advise listing explicit origins
# when credentials are allowed. Left as-is to preserve current behavior.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Files in the directory of the ``whisperlivekit.web`` package are served
# under the /web prefix.
web_dir = pathlib.Path(webpkg.__file__).parent
app.mount("/web", StaticFiles(directory=f"{web_dir}"), name="web")
|
|
|
|
@app.get("/")
async def get():
    """Return the web interface page produced by ``get_web_interface_html()``."""
    page_html = get_web_interface_html()
    return HTMLResponse(page_html)
|
|
|
|
|
|
async def handle_websocket_results(websocket, results_generator):
    """Forward every result from the audio processor to the client as JSON.

    Once ``results_generator`` is exhausted (all audio has been processed), a
    final ``{"type": "ready_to_stop"}`` message is sent. A client disconnect
    is logged at INFO and any other exception is logged with its traceback;
    neither is re-raised.
    """
    try:
        async for payload in results_generator:
            await websocket.send_json(payload)
        # Exhaustion of the generator means there is nothing left to send.
        logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
        await websocket.send_json({"type": "ready_to_stop"})
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected while handling results (client likely closed connection).")
    except Exception as e:
        logger.exception(f"Error in WebSocket results handler: {e}")
|
|
|
|
|
|
async def _drain_results_task(task):
    """Stop the results-forwarding task and wait for it to finish.

    The task is cancelled only if it is still running, and is awaited either
    way. A ``CancelledError`` is logged at INFO and any other ``Exception`` at
    WARNING; both are swallowed so the caller's cleanup can continue.
    """
    if not task.done():
        task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        logger.info("WebSocket results handler task was cancelled.")
    except Exception as e:
        logger.warning(f"Exception while awaiting websocket_task completion: {e}")


@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Receive binary audio frames from a client and stream results back.

    A new ``AudioProcessor`` bound to the shared ``transcription_engine`` is
    created for each connection. Incoming bytes are fed to ``process_audio``,
    while a background task running ``handle_websocket_results`` pushes
    results to the same socket. The receive loop only ends on an exception.
    Afterwards the results task is stopped and the processor is cleaned up.
    """
    processor = AudioProcessor(transcription_engine=transcription_engine)
    await websocket.accept()
    logger.info("WebSocket connection opened.")

    results = await processor.create_tasks()
    sender = asyncio.create_task(handle_websocket_results(websocket, results))

    try:
        while True:
            frame = await websocket.receive_bytes()
            await processor.process_audio(frame)
    except KeyError as e:
        # NOTE(review): a KeyError mentioning 'bytes' is treated as the client
        # going away. It presumably comes from receive_bytes() getting a
        # message without a 'bytes' entry (a text frame may also trigger
        # this); confirm against the Starlette version in use.
        if "bytes" not in str(e):
            logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
        else:
            logger.warning("Client has closed the connection.")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected by client during message receiving loop.")
    except Exception as e:
        logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True)
    finally:
        logger.info("Cleaning up WebSocket endpoint...")
        await _drain_results_task(sender)
        await processor.cleanup()
        logger.info("WebSocket endpoint cleaned up successfully.")
|
|
|
|
def main():
    """Entry point for the CLI command: serve this app with uvicorn.

    Host and port come from the parsed CLI arguments. TLS is enabled only when
    both ``--ssl-certfile`` and ``--ssl-keyfile`` are given; supplying just
    one of them raises ``ValueError`` before the server starts.
    """
    import uvicorn

    run_options = dict(
        app="whisperlivekit.basic_server:app",
        host=args.host,
        port=args.port,
        reload=False,
        log_level="info",
        lifespan="on",
    )

    certfile, keyfile = args.ssl_certfile, args.ssl_keyfile
    if certfile or keyfile:
        if not (certfile and keyfile):
            raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
        run_options.update(ssl_certfile=certfile, ssl_keyfile=keyfile)

    uvicorn.run(**run_options)


if __name__ == "__main__":
    main()
|