# Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
# Synced 2026-03-07 14:23:18 +00:00. Original file: 129 lines, 4.4 KiB, Python.
import asyncio
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import HTMLResponse
|
|
|
|
from whisperlivekit import (AudioProcessor, TranscriptionEngine,
|
|
get_inline_ui_html, parse_args)
|
|
|
|
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# basicConfig() installs a stream handler on the root logger (if it has none
# yet). The root logger's level is then raised to WARNING, so loggers without
# an explicit level only emit warnings and errors.
logging.getLogger().setLevel(logging.WARNING)

# This module's own logger is explicitly set to DEBUG. Its records still
# propagate to the root handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Command-line configuration, parsed once at import time. It is read by
# lifespan(), websocket_endpoint() and main().
config = parse_args()

# Shared engine instance; populated by lifespan() when the application starts.
transcription_engine = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build the shared TranscriptionEngine at startup.

    The engine is constructed once from the module-level ``config``. It is
    stored in the module-global ``transcription_engine``, which every
    WebSocket connection then hands to its AudioProcessor. No teardown is
    performed after the ``yield`` (application shutdown).
    """
    global transcription_engine
    engine = TranscriptionEngine(config=config)
    transcription_engine = engine
    yield
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Fully permissive CORS: any origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True let any
# site make credentialed cross-origin requests (Starlette echoes the
# requesting origin in that case). Confirm this is intended beyond local use.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
@app.get("/")
async def get():
    """Serve the HTML produced by ``get_inline_ui_html()`` at the root path."""
    page = get_inline_ui_html()
    return HTMLResponse(page)
|
|
|
|
|
|
async def handle_websocket_results(websocket, results_generator):
    """Stream every item of *results_generator* to the client as JSON.

    Each item is serialized with its ``to_dict()`` method and sent via
    ``websocket.send_json``. When the generator is exhausted, one final
    ``{"type": "ready_to_stop"}`` message is sent.

    A ``WebSocketDisconnect`` is logged at INFO level. Any other ``Exception``
    is logged with its traceback. Neither is re-raised. Cancellation is not
    caught and propagates to the caller.
    """
    try:
        async for result in results_generator:
            payload = result.to_dict()
            await websocket.send_json(payload)

        # Generator exhaustion signals that all submitted audio has been processed.
        logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
        await websocket.send_json({"type": "ready_to_stop"})
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected while handling results (client likely closed connection).")
    except Exception as exc:
        logger.exception(f"Error in WebSocket results handler: {exc}")
|
|
|
|
|
|
async def _send_client_config(websocket):
    """Send the initial ``config`` message to the client.

    ``useAudioWorklet`` mirrors ``config.pcm_input``. A failure is logged as a
    warning and not re-raised, so the session continues without it.
    """
    try:
        await websocket.send_json({"type": "config", "useAudioWorklet": bool(config.pcm_input)})
    except Exception as e:
        logger.warning(f"Failed to send config to client: {e}")


async def _receive_audio_until_closed(websocket, audio_processor):
    """Forward binary frames from *websocket* to *audio_processor* until the loop ends.

    The loop only stops through an exception. Every ``Exception`` is logged
    here and swallowed so the caller can proceed to cleanup. Cancellation
    (``asyncio.CancelledError``) is not caught and propagates.
    """
    try:
        while True:
            message = await websocket.receive_bytes()
            await audio_processor.process_audio(message)
    except KeyError as e:
        # NOTE(review): receive_bytes() presumably raises KeyError('bytes') when
        # the incoming frame carries no binary payload (e.g. a text frame).
        # That case is treated here as the client ending the stream. Confirm.
        if 'bytes' in str(e):
            logger.warning("Client has closed the connection.")
        else:
            logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected by client during message receiving loop.")
    except Exception as e:
        logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True)


async def _stop_results_task(websocket_task):
    """Cancel the results-forwarding task if it is still running, then await it.

    ``None`` (the task was never created) is a no-op. Cancellation, and any
    other ``Exception`` raised by the task, is logged rather than propagated.
    """
    if websocket_task is None:
        return
    if not websocket_task.done():
        websocket_task.cancel()
    try:
        await websocket_task
    except asyncio.CancelledError:
        logger.info("WebSocket results handler task was cancelled.")
    except Exception as e:
        logger.warning(f"Exception while awaiting websocket_task completion: {e}")


@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Handle one ``/asr`` WebSocket session.

    The session proceeds in order:

    1. Accept the connection.
    2. Send a ``config`` message.
    3. Start the audio processor's tasks.
    4. Spawn ``handle_websocket_results`` as a background task that streams
       results back to the client.
    5. Feed every incoming binary frame to ``audio_processor.process_audio``
       until the client disconnects or an error ends the loop.

    Cleanup (stopping the results task, then ``audio_processor.cleanup()``)
    runs whenever the processor was created. This includes the cases where
    ``accept()`` or ``create_tasks()`` raise or the handler is cancelled.
    Exceptions from those setup steps still propagate after cleanup. Errors
    inside the receive loop are logged, not raised.
    """
    # transcription_engine is only read here, so no ``global`` statement is needed.
    audio_processor = AudioProcessor(
        transcription_engine=transcription_engine,
    )
    websocket_task = None
    try:
        await websocket.accept()
        logger.info("WebSocket connection opened.")

        await _send_client_config(websocket)

        results_generator = await audio_processor.create_tasks()
        websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))

        await _receive_audio_until_closed(websocket, audio_processor)
    finally:
        logger.info("Cleaning up WebSocket endpoint...")
        await _stop_results_task(websocket_task)
        # NOTE(review): on the failure paths (accept()/create_tasks() raising),
        # this assumes AudioProcessor.cleanup() tolerates a processor whose
        # tasks were never (fully) started. Confirm against AudioProcessor.
        await audio_processor.cleanup()
        logger.info("WebSocket endpoint cleaned up successfully.")
|
|
|
|
def main():
    """CLI entry point: serve this module's ``app`` with uvicorn.

    Host, port, optional TLS files and optional ``forwarded_allow_ips`` come
    from the module-level ``config``. The app is passed to uvicorn as an
    import string (``"module:attr"``), so uvicorn imports it itself.

    Raises:
        ValueError: if exactly one of ``--ssl-certfile`` / ``--ssl-keyfile`` is set.
    """
    import uvicorn

    run_options = dict(
        app="whisperlivekit.basic_server:app",
        host=config.host,
        port=config.port,
        reload=False,
        log_level="info",
        lifespan="on",
    )

    certfile = config.ssl_certfile
    keyfile = config.ssl_keyfile
    if certfile or keyfile:
        # TLS needs both the certificate and the key; reject a half-configured setup.
        if not (certfile and keyfile):
            raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
        run_options.update(ssl_certfile=certfile, ssl_keyfile=keyfile)

    if config.forwarded_allow_ips:
        run_options["forwarded_allow_ips"] = config.forwarded_allow_ips

    uvicorn.run(**run_options)


if __name__ == "__main__":
    main()
|