Files
WhisperLiveKit/whisperlivekit/basic_server.py
2026-03-14 00:13:29 +01:00

372 lines
13 KiB
Python

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import List, Optional
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from whisperlivekit import AudioProcessor, TranscriptionEngine, get_inline_ui_html, parse_args
# Install the root handler/format first; the next line then overrides the root
# level so that loggers without an explicit level only emit WARNING and above.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
# This module's own logger is explicitly raised to DEBUG verbosity.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# The "whisperlivekit.qwen3_asr" logger is likewise set to DEBUG.
logging.getLogger("whisperlivekit.qwen3_asr").setLevel(logging.DEBUG)
# Server configuration is resolved once, at import time.
# NOTE(review): presumably parse_args() reads the process command line — confirm.
config = parse_args()
# Shared engine instance; stays None until lifespan() assigns it at startup.
transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build the shared transcription engine.

    The engine is constructed from the module-level ``config`` before the
    application starts serving and is published through the module-level
    ``transcription_engine`` global. Nothing is torn down after the ``yield``.
    """
    global transcription_engine
    engine = TranscriptionEngine(config=config)
    transcription_engine = engine
    yield
# ASGI application; lifespan() runs at startup to create the shared engine.
app = FastAPI(lifespan=lifespan)
# CORS policy: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for deployments beyond local use.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def get():
    """Serve the inline web UI produced by ``get_inline_ui_html()``."""
    page = get_inline_ui_html()
    return HTMLResponse(page)
@app.get("/health")
async def health():
    """Health check endpoint.

    Always reports ``status: ok``. ``backend`` is the engine config's
    ``backend`` attribute (falling back to ``"whisper"``), or ``None`` while
    no engine is available; ``ready`` tells whether the engine has been set.
    """
    engine = transcription_engine
    if engine:
        backend = getattr(engine.config, "backend", "whisper")
    else:
        backend = None
    payload = {
        "status": "ok",
        "backend": backend,
        "ready": engine is not None,
    }
    return JSONResponse(payload)
async def handle_websocket_results(websocket, results_generator, diff_tracker=None):
    """Forward every result from the audio processor to the WebSocket client.

    Args:
        websocket: Connection on which JSON messages are sent.
        results_generator: Async iterable yielding result objects.
        diff_tracker: Optional tracker; when given, each result is converted
            with ``diff_tracker.to_message(...)`` instead of ``to_dict()``.

    Once the generator is exhausted a ``ready_to_stop`` message is sent.
    A client disconnect and any other exception are logged, not re-raised.
    """
    try:
        async for response in results_generator:
            if diff_tracker is None:
                outgoing = response.to_dict()
            else:
                outgoing = diff_tracker.to_message(response)
            await websocket.send_json(outgoing)
        # An exhausted generator means all audio has been processed.
        logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
        await websocket.send_json({"type": "ready_to_stop"})
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected while handling results (client likely closed connection).")
    except Exception as e:
        logger.exception(f"Error in WebSocket results handler: {e}")
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Live transcription WebSocket.

    Per-session options come from the query string: ``language`` (optional)
    and ``mode`` (``"full"`` by default, ``"diff"`` enables a ``DiffTracker``).
    Incoming binary frames are passed to an ``AudioProcessor`` while a
    background task streams results back to the client. On exit the results
    task is cancelled if still running and the processor is cleaned up.
    """
    global transcription_engine

    # Per-session options are read from the query parameters.
    params = websocket.query_params
    session_language = params.get("language", None)
    mode = params.get("mode", "full")

    audio_processor = AudioProcessor(
        transcription_engine=transcription_engine,
        language=session_language,
    )
    await websocket.accept()

    language_note = f" language={session_language}" if session_language else ""
    logger.info("WebSocket connection opened.%s", language_note)

    if mode == "diff":
        from whisperlivekit.diff_protocol import DiffTracker

        diff_tracker = DiffTracker()
        logger.info("Client requested diff mode")
    else:
        diff_tracker = None

    # Best effort: a failure to deliver the config message is only logged.
    try:
        await websocket.send_json(
            {"type": "config", "useAudioWorklet": bool(config.pcm_input), "mode": mode}
        )
    except Exception as e:
        logger.warning(f"Failed to send config to client: {e}")

    results_generator = await audio_processor.create_tasks()
    websocket_task = asyncio.create_task(
        handle_websocket_results(websocket, results_generator, diff_tracker)
    )

    try:
        while True:
            await audio_processor.process_audio(await websocket.receive_bytes())
    except KeyError as e:
        # A KeyError mentioning 'bytes' is treated as the client closing.
        if 'bytes' not in str(e):
            logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
        else:
            logger.warning("Client has closed the connection.")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected by client during message receiving loop.")
    except Exception as e:
        logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True)
    finally:
        logger.info("Cleaning up WebSocket endpoint...")
        if not websocket_task.done():
            websocket_task.cancel()
        try:
            await websocket_task
        except asyncio.CancelledError:
            logger.info("WebSocket results handler task was cancelled.")
        except Exception as e:
            logger.warning(f"Exception while awaiting websocket_task completion: {e}")
        await audio_processor.cleanup()
        logger.info("WebSocket endpoint cleaned up successfully.")
# ---------------------------------------------------------------------------
# Deepgram-compatible WebSocket API (/v1/listen)
# ---------------------------------------------------------------------------
@app.websocket("/v1/listen")
async def deepgram_websocket_endpoint(websocket: WebSocket):
    """Deepgram-compatible live transcription WebSocket.

    Delegates the whole session to ``handle_deepgram_websocket``, passing the
    shared transcription engine and the server config.
    """
    # Imported at call time, so the compat module is only loaded when used.
    from whisperlivekit import deepgram_compat

    await deepgram_compat.handle_deepgram_websocket(websocket, transcription_engine, config)
# ---------------------------------------------------------------------------
# OpenAI-compatible REST API (/v1/audio/transcriptions)
# ---------------------------------------------------------------------------
async def _convert_to_pcm(audio_bytes: bytes) -> bytes:
    """Convert any audio format to PCM s16le mono 16kHz using ffmpeg.

    The input is piped to an ``ffmpeg`` subprocess on stdin and the converted
    audio is read back from stdout.

    Raises:
        HTTPException: status 400 carrying ffmpeg's stderr text when the
            process exits with a non-zero return code.
    """
    ffmpeg_cmd = (
        "ffmpeg", "-i", "pipe:0",
        "-f", "s16le", "-acodec", "pcm_s16le",
        "-ar", "16000", "-ac", "1",
        "-loglevel", "error",
        "pipe:1",
    )
    pipe = asyncio.subprocess.PIPE
    proc = await asyncio.create_subprocess_exec(
        *ffmpeg_cmd, stdin=pipe, stdout=pipe, stderr=pipe
    )
    pcm_out, err_out = await proc.communicate(input=audio_bytes)
    if proc.returncode == 0:
        return pcm_out

    from fastapi import HTTPException

    reason = err_out.decode().strip()
    raise HTTPException(status_code=400, detail=f"Audio conversion failed: {reason}")
def _parse_time_str(time_str: str) -> float:
"""Parse 'H:MM:SS.cc' to seconds."""
parts = time_str.split(":")
if len(parts) == 3:
return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
if len(parts) == 2:
return int(parts[0]) * 60 + float(parts[1])
return float(parts[0])
def _format_openai_response(front_data, response_format: str, language: Optional[str], duration: float) -> dict:
"""Convert FrontData to OpenAI-compatible response."""
d = front_data.to_dict()
lines = d.get("lines", [])
# Combine all speech text (exclude silence segments)
text_parts = [l["text"] for l in lines if l.get("text") and l.get("speaker", 0) != -2]
full_text = " ".join(text_parts).strip()
if response_format == "text":
return full_text
# Build segments and words for verbose_json
segments = []
words = []
for i, line in enumerate(lines):
if line.get("speaker") == -2 or not line.get("text"):
continue
start = _parse_time_str(line.get("start", "0:00:00"))
end = _parse_time_str(line.get("end", "0:00:00"))
segments.append({
"id": len(segments),
"start": round(start, 2),
"end": round(end, 2),
"text": line["text"],
})
# Split segment text into approximate words with estimated timestamps
seg_words = line["text"].split()
if seg_words:
word_duration = (end - start) / max(len(seg_words), 1)
for j, word in enumerate(seg_words):
words.append({
"word": word,
"start": round(start + j * word_duration, 2),
"end": round(start + (j + 1) * word_duration, 2),
})
if response_format == "verbose_json":
return {
"task": "transcribe",
"language": language or "unknown",
"duration": round(duration, 2),
"text": full_text,
"words": words,
"segments": segments,
}
if response_format in ("srt", "vtt"):
lines_out = []
if response_format == "vtt":
lines_out.append("WEBVTT\n")
for i, seg in enumerate(segments):
start_ts = _srt_timestamp(seg["start"], response_format)
end_ts = _srt_timestamp(seg["end"], response_format)
if response_format == "srt":
lines_out.append(f"{i + 1}")
lines_out.append(f"{start_ts} --> {end_ts}")
lines_out.append(seg["text"])
lines_out.append("")
return "\n".join(lines_out)
# Default: json
return {"text": full_text}
def _srt_timestamp(seconds: float, fmt: str) -> str:
"""Format seconds as SRT (HH:MM:SS,mmm) or VTT (HH:MM:SS.mmm) timestamp."""
h = int(seconds // 3600)
m = int((seconds % 3600) // 60)
s = int(seconds % 60)
ms = int(round((seconds % 1) * 1000))
sep = "," if fmt == "srt" else "."
return f"{h:02d}:{m:02d}:{s:02d}{sep}{ms:03d}"
@app.post("/v1/audio/transcriptions")
async def create_transcription(
    file: UploadFile = File(...),
    model: str = Form(default=""),
    language: Optional[str] = Form(default=None),
    prompt: str = Form(default=""),
    response_format: str = Form(default="json"),
    timestamp_granularities: Optional[List[str]] = Form(default=None),
):
    """OpenAI-compatible audio transcription endpoint.

    Accepts the same parameters as OpenAI's /v1/audio/transcriptions API.
    The `model` parameter is accepted but ignored (uses the server's configured
    backend); `prompt` and `timestamp_granularities` are likewise accepted but
    not used by this handler.

    The upload is converted to 16 kHz mono 16-bit PCM, fed through an
    ``AudioProcessor`` in one-second chunks, and the last result emitted by
    the pipeline is formatted according to ``response_format``.

    Raises:
        HTTPException: status 400 when the uploaded file is empty, or when
            the audio conversion fails.
    """
    global transcription_engine
    audio_bytes = await file.read()
    if not audio_bytes:
        from fastapi import HTTPException
        raise HTTPException(status_code=400, detail="Empty audio file")

    # Convert to PCM for pipeline processing
    pcm_data = await _convert_to_pcm(audio_bytes)
    duration = len(pcm_data) / (16000 * 2)  # 16kHz, 16-bit

    # Process through the full pipeline
    processor = AudioProcessor(
        transcription_engine=transcription_engine,
        language=language,
    )
    # Force PCM input regardless of server config
    processor.is_pcm_input = True
    results_gen = await processor.create_tasks()

    # Collect results in background while feeding audio
    final_result = None

    async def collect():
        nonlocal final_result
        async for result in results_gen:
            final_result = result

    collect_task = asyncio.create_task(collect())

    # Feeding and waiting share one try/finally so that a failure while
    # pushing audio still stops the collector and releases the processor.
    try:
        # Feed audio in chunks (1 second each)
        chunk_size = 16000 * 2  # 1 second of PCM
        for i in range(0, len(pcm_data), chunk_size):
            await processor.process_audio(pcm_data[i:i + chunk_size])
        # Signal end of audio
        await processor.process_audio(b"")

        # Wait for pipeline to finish
        try:
            await asyncio.wait_for(collect_task, timeout=120.0)
        except asyncio.TimeoutError:
            logger.warning("Transcription timed out after 120s")
    finally:
        if not collect_task.done():
            # Only reached when feeding or waiting failed: stop the collector
            # and absorb its outcome so the original error propagates.
            collect_task.cancel()
            await asyncio.gather(collect_task, return_exceptions=True)
        await processor.cleanup()

    if final_result is None:
        return JSONResponse({"text": ""})

    result = _format_openai_response(final_result, response_format, language, duration)
    if isinstance(result, str):
        return PlainTextResponse(result)
    return JSONResponse(result)
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing endpoint.

    Reports a single model whose id is derived from the engine config's
    ``backend`` and ``model_size`` attributes. The defaults ``"whisper"`` and
    ``"base"`` are used while no engine is available or an attribute is absent.
    """
    engine = transcription_engine
    if engine:
        backend = getattr(engine.config, "backend", "whisper")
        model_size = getattr(engine.config, "model_size", "base")
    else:
        backend, model_size = "whisper", "base"

    if backend == "whisper":
        model_id = f"whisper-{model_size}"
    else:
        model_id = f"{backend}/{model_size}"

    model_entry = {
        "id": model_id,
        "object": "model",
        "owned_by": "whisperlivekit",
    }
    return JSONResponse({"object": "list", "data": [model_entry]})
def main():
    """Entry point for the CLI command.

    Prints the startup banner, then runs uvicorn for this module's ``app``
    with the configured host and port. SSL files and ``forwarded_allow_ips``
    are passed through only when configured.

    Raises:
        ValueError: when only one of the SSL certificate and key files is set.
    """
    import uvicorn
    from whisperlivekit.cli import print_banner

    certfile = config.ssl_certfile
    keyfile = config.ssl_keyfile
    print_banner(config, config.host, config.port, ssl=bool(certfile and keyfile))

    run_options = dict(
        app="whisperlivekit.basic_server:app",
        host=config.host,
        port=config.port,
        reload=False,
        log_level="info",
        lifespan="on",
    )

    if certfile or keyfile:
        # A certificate without its key (or the reverse) is a config error.
        if not (certfile and keyfile):
            raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
        run_options["ssl_certfile"] = certfile
        run_options["ssl_keyfile"] = keyfile

    if config.forwarded_allow_ips:
        run_options["forwarded_allow_ips"] = config.forwarded_allow_ips

    uvicorn.run(**run_options)


if __name__ == "__main__":
    main()