mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
- Updated README.md to include installation instructions for PyAudioWPatch. - Modified setup.py to add PyAudioWPatch as an optional dependency. - Enhanced audio_processor.py to initialize and handle PyAudioWPatch for system audio capture. - Updated basic_server.py to manage audio input modes and integrate PyAudioWPatch processing. - Refactored core.py to include audio input argument parsing.
196 lines
6.9 KiB
Python
196 lines
6.9 KiB
Python
from contextlib import asynccontextmanager
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import HTMLResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from whisperlivekit import WhisperLiveKit, get_parsed_args
|
|
from whisperlivekit.audio_processor import AudioProcessor
|
|
|
|
import asyncio
|
|
import logging
|
|
import os, sys
|
|
|
|
# Root logging: install a timestamped handler via basicConfig, then put the
# root threshold at WARNING so loggers without their own level (e.g. other
# libraries) only emit WARNING and above. The level=INFO passed to
# basicConfig is superseded by the explicit setLevel right after it.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logging.getLogger(None).setLevel(logging.WARNING)

# This module's own logger is fully verbose by default; main() re-levels it
# from --log-level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build shared server state on startup and release it on shutdown.

    On startup this stores on ``app.state``:

    * ``kit``: a ``WhisperLiveKit`` instance, whose ``args.audio_input``
      selects the input mode;
    * ``audio_processor``: one ``AudioProcessor`` created here;
    * ``results_generator``: the value returned by
      ``audio_processor.create_tasks()`` in ``"pyaudiowpatch"`` mode,
      otherwise ``None``.

    A failure while starting PyAudioWPatch processing is logged at CRITICAL
    and does not abort startup; ``results_generator`` then stays ``None``.
    On shutdown the stored ``audio_processor`` (if any) is cleaned up.
    """
    logger.info("Starting up...")
    state = app.state

    state.kit = kit = WhisperLiveKit()
    logger.info(f"Audio Input mode: {kit.args.audio_input}")

    state.audio_processor = processor = AudioProcessor()
    state.results_generator = None  # filled in below only in pyaudiowpatch mode

    if kit.args.audio_input != "pyaudiowpatch":
        logger.info("WebSocket input mode selected. Processing will start on client connection.")
    else:
        logger.info("Starting PyAudioWPatch processing tasks...")
        try:
            state.results_generator = await processor.create_tasks()
        except Exception as e:
            # Keep the server up; /asr clients will be told processing is unavailable.
            logger.critical(f"Failed to start PyAudioWPatch processing: {e}", exc_info=True)

    yield

    logger.info("Shutting down...")
    processor_to_clean = getattr(state, "audio_processor", None)
    if processor_to_clean:
        logger.info("Cleaning up AudioProcessor...")
        await processor_to_clean.cleanup()
    logger.info("Shutdown complete.")
|
|
|
|
|
|
# The ASGI application; `lifespan` builds the shared state before serving.
app = FastAPI(lifespan=lifespan)

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): credentials combined with a wildcard origin is the most
# permissive CORS setup; nothing in this module appears to rely on cookies
# or auth, so confirm allow_credentials=True is actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
|
|
@app.get("/")
async def get():
    """Serve the HTML page produced by ``app.state.kit.web_interface()``."""
    page = app.state.kit.web_interface()
    return HTMLResponse(page)
|
|
|
|
|
|
async def handle_websocket_results(websocket: WebSocket, results_generator):
    """Relay every item from ``results_generator`` to ``websocket`` as JSON.

    Returns once the generator is exhausted. Any ``Exception`` raised while
    iterating or sending ends the relay and is logged as a warning (no
    traceback) rather than propagated. ``asyncio.CancelledError`` derives from
    ``BaseException`` on Python 3.8+, so cancellation is not caught here.
    """
    try:
        async for payload in results_generator:
            await websocket.send_json(payload)
    except Exception as exc:
        logger.warning(f"Error in WebSocket results handler: {exc}")
|
|
|
|
|
|
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client of the ``/asr`` WebSocket.

    Behaviour depends on ``app.state.kit.args.audio_input``:

    * ``"websocket"``: the client streams audio as binary messages. This
      connection gets its own ``AudioProcessor``. Each message is passed to
      ``process_audio`` while results are sent back as JSON. When either the
      receive loop or the result stream finishes, the other is cancelled and
      the processor is cleaned up.
    * ``"pyaudiowpatch"``: audio is captured server-side. Results from the
      generator created at startup (``app.state.results_generator``) are
      forwarded to this client. The socket is closed with code 1011 if that
      generator is missing.
    * any other value: the socket is closed with code 1011.

    Unexpected errors are logged and the socket is closed with code 1011
    (best effort).
    """
    await websocket.accept()
    logger.info("WebSocket connection accepted.")

    kit_args = app.state.kit.args
    connection_processor = None  # per-connection processor, websocket mode only
    websocket_task = None
    receive_task = None

    try:
        if kit_args.audio_input == "websocket":
            logger.info("WebSocket mode: Starting processing tasks for this connection.")
            # Each connection gets its own processor (cleaned up in `finally`)
            # rather than re-running create_tasks() on the single instance
            # created at startup, which every client would otherwise share.
            connection_processor = AudioProcessor()
            results_generator = await connection_processor.create_tasks()
            websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))

            async def receive_audio():
                """Feed binary messages to the processor until disconnect or error."""
                try:
                    while True:
                        message = await websocket.receive_bytes()
                        await connection_processor.process_audio(message)
                except WebSocketDisconnect:
                    logger.info("WebSocket disconnected by client (receive_audio).")
                except Exception as e:
                    logger.error(f"Error receiving audio: {e}", exc_info=True)
                finally:
                    logger.debug("Receive audio task finished.")

            receive_task = asyncio.create_task(receive_audio())
            # Stop as soon as either direction ends; whichever task is still
            # running is cancelled and awaited in `finally`.
            await asyncio.wait(
                {websocket_task, receive_task},
                return_when=asyncio.FIRST_COMPLETED,
            )

        elif kit_args.audio_input == "pyaudiowpatch":
            logger.info("PyAudioWPatch mode: Streaming existing results.")
            results_generator = app.state.results_generator
            if results_generator is None:
                logger.error("PyAudioWPatch results generator not available. Was startup successful?")
                await websocket.close(code=1011, reason="Server error: Audio processing not started.")
                return

            # NOTE(review): this generator is shared by every connected client.
            # Presumably concurrent consumers compete for items (or hit
            # "already running" errors), and cancelling this task while it
            # awaits the generator may terminate it for everyone. A broadcast
            # layer would be needed for true fan-out — confirm intended usage.
            websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
            await websocket_task

        else:
            logger.error(f"Unsupported audio input mode configured: {kit_args.audio_input}")
            await websocket.close(code=1011, reason="Server configuration error.")

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected by client.")
    except Exception as e:
        logger.error(f"Error in WebSocket endpoint: {e}", exc_info=True)
        # Attempt to close gracefully; the socket may already be unusable.
        try:
            await websocket.close(code=1011, reason=f"Server error: {e}")
        except Exception:
            pass  # Ignore errors during close after another error
    finally:
        logger.info("Cleaning up WebSocket connection...")
        unfinished = [
            task
            for task in (websocket_task, receive_task)
            if task is not None and not task.done()
        ]
        for task in unfinished:
            task.cancel()
        if unfinished:
            # Wait for cancellation to complete so no task still touches the
            # websocket or the processor after this handler returns.
            await asyncio.gather(*unfinished, return_exceptions=True)

        if connection_processor is not None:
            try:
                await connection_processor.cleanup()
            except Exception as e:
                logger.warning(f"Error cleaning up connection AudioProcessor: {e}", exc_info=True)

        logger.info("WebSocket connection closed.")
|
|
|
|
def _uvicorn_log_level(numeric_level: int) -> str:
    """Map a stdlib numeric log level to a level name uvicorn accepts.

    uvicorn only accepts critical/error/warning/info/debug/trace. DEBUG (and
    anything more verbose) maps to uvicorn's extra-verbose ``"trace"``, which
    preserves this server's existing debug -> trace mapping.
    """
    if numeric_level <= logging.DEBUG:
        return "trace"
    if numeric_level <= logging.INFO:
        return "info"
    if numeric_level <= logging.WARNING:
        return "warning"
    if numeric_level <= logging.ERROR:
        return "error"
    return "critical"


def main():
    """Entry point for the CLI command.

    Applies ``--log-level`` to the root logger and this module's logger,
    validates the SSL options, then starts uvicorn serving
    ``whisperlivekit.basic_server:app``.

    Raises:
        ValueError: if only one of ``--ssl-certfile`` / ``--ssl-keyfile``
            is given.
    """
    import uvicorn

    # Get the globally parsed arguments.
    args = get_parsed_args()

    # Resolve --log-level to a numeric stdlib level. Only int attributes of
    # the logging module (DEBUG, INFO, WARNING, ...) count as valid names.
    log_level_name = args.log_level.upper()
    numeric_level = getattr(logging, log_level_name, None)
    if not isinstance(numeric_level, int):
        logger.warning(f"Invalid log level: {args.log_level}. Defaulting to INFO.")
        numeric_level = logging.INFO
        log_level_name = "INFO"

    logging.getLogger().setLevel(numeric_level)  # root logger
    logger.setLevel(numeric_level)  # this module's logger
    logger.info(f"Log level set to: {log_level_name}")

    uvicorn_kwargs = {
        "app": "whisperlivekit.basic_server:app",
        "host": args.host,
        "port": args.port,
        "reload": False,
        # Derived from the validated numeric level so an invalid or
        # non-uvicorn name (e.g. "WARN", "FATAL") never reaches uvicorn.
        "log_level": _uvicorn_log_level(numeric_level),
        "lifespan": "on",
    }

    if args.ssl_certfile or args.ssl_keyfile:
        if not (args.ssl_certfile and args.ssl_keyfile):
            raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
        uvicorn_kwargs["ssl_certfile"] = args.ssl_certfile
        uvicorn_kwargs["ssl_keyfile"] = args.ssl_keyfile

    uvicorn.run(**uvicorn_kwargs)
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly as well as through the CLI entry point.
    main()
|