import io import argparse import asyncio import numpy as np import ffmpeg from time import time from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args import subprocess import math ##### LOAD ARGS ##### parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server") parser.add_argument( "--host", type=str, default="localhost", help="The host address to bind the server to.", ) parser.add_argument( "--port", type=int, default=8000, help="The port number to bind the server to." ) parser.add_argument( "--warmup-file", type=str, dest="warmup_file", help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .", ) parser.add_argument( "--diarization", type=bool, default=False, help="Whether to enable speaker diarization.", ) add_shared_args(parser) args = parser.parse_args() SAMPLE_RATE = 16000 CHANNELS = 1 SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size) BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE if args.diarization: from src.diarization.diarization_online import DiartDiarization ##### LOAD APP ##### @asynccontextmanager async def lifespan(app: FastAPI): global asr, tokenizer asr, tokenizer = backend_factory(args) yield app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Load demo HTML for the root endpoint with open("src/web/live_transcription.html", "r", encoding="utf-8") as f: html = f.read() async def start_ffmpeg_decoder(): """ Start an FFmpeg process in async streaming mode that reads WebM from stdin and outputs raw s16le PCM on stdout. Returns the process object. """ process = ( ffmpeg.input("pipe:0", format="webm") .output( "pipe:1", format="s16le", acodec="pcm_s16le", ac=CHANNELS, ar=str(SAMPLE_RATE), ) .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True) ) return process ##### ENDPOINTS ##### @app.get("/") async def get(): return HTMLResponse(html) @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() print("WebSocket connection opened.") ffmpeg_process = None pcm_buffer = bytearray() online = online_factory(args, asr, tokenizer) diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None async def restart_ffmpeg(): nonlocal ffmpeg_process, online, diarization, pcm_buffer if ffmpeg_process: try: ffmpeg_process.kill() await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait) except Exception as e: print(f"Error killing FFmpeg process: {e}") ffmpeg_process = await start_ffmpeg_decoder() pcm_buffer = bytearray() online = online_factory(args, asr, tokenizer) if args.diarization: diarization = DiartDiarization(SAMPLE_RATE) print("FFmpeg process started.") await restart_ffmpeg() async def ffmpeg_stdout_reader(): nonlocal ffmpeg_process, online, diarization, pcm_buffer loop = asyncio.get_event_loop() full_transcription = "" beg = time() chunk_history = [] # Will store dicts: {beg, end, text, speaker} while True: try: elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096) beg = time() # Read chunk with timeout try: chunk = await asyncio.wait_for( loop.run_in_executor( None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration ), timeout=5.0 ) except asyncio.TimeoutError: print("FFmpeg read timeout. Restarting...") await restart_ffmpeg() full_transcription = "" chunk_history = [] beg = time() continue # Skip processing and read from new process if not chunk: print("FFmpeg stdout closed.") break pcm_buffer.extend(chunk) if len(pcm_buffer) >= BYTES_PER_SEC: # Convert int16 -> float32 pcm_array = ( np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 ) pcm_buffer = bytearray() online.insert_audio_chunk(pcm_array) transcription = online.process_iter() if transcription: chunk_history.append({ "beg": transcription.start, "end": transcription.end, "text": transcription.text, "speaker": "0" }) full_transcription += transcription.text if transcription else "" buffer = online.get_buffer() if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed buffer = "" lines = [ { "speaker": "0", "text": "", } ] if args.diarization: await diarization.diarize(pcm_array) diarization.assign_speakers_to_chunks(chunk_history) for ch in chunk_history: if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]: lines.append( { "speaker": ch["speaker"][-1], "text": ch['text'] } ) else: lines[-1]["text"] += ch['text'] response = {"lines": lines, "buffer": buffer} await websocket.send_json(response) except Exception as e: print(f"Exception in ffmpeg_stdout_reader: {e}") break print("Exiting ffmpeg_stdout_reader...") stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader()) try: while True: # Receive incoming WebM audio chunks from the client message = await websocket.receive_bytes() try: ffmpeg_process.stdin.write(message) ffmpeg_process.stdin.flush() except (BrokenPipeError, AttributeError) as e: print(f"Error writing to FFmpeg: {e}. Restarting...") await restart_ffmpeg() ffmpeg_process.stdin.write(message) ffmpeg_process.stdin.flush() except WebSocketDisconnect: print("WebSocket disconnected.") finally: stdout_reader_task.cancel() try: ffmpeg_process.stdin.close() ffmpeg_process.wait() except: pass if args.diarization: diarization.close() if __name__ == "__main__": import uvicorn uvicorn.run( "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True, log_level="info" )