diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index c8b8436..89b010c 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -14,8 +14,14 @@ from src.whisper_streaming.whisper_online import backend_factory, online_factory import subprocess import math +import logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logging.getLogger().setLevel(logging.WARNING) +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + ##### LOAD ARGS ##### parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server") @@ -51,6 +57,7 @@ CHANNELS = 1 SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size) BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE +MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of mono s16le audio at 16 kHz (32000 bytes/s) if args.diarization: from src.diarization.diarization_online import DiartDiarization @@ -106,7 +113,7 @@ async def get(): @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - print("WebSocket connection opened.") + logger.info("WebSocket connection opened.") ffmpeg_process = None pcm_buffer = bytearray() @@ -120,13 +127,13 @@ async def websocket_endpoint(websocket: WebSocket): ffmpeg_process.kill() await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait) except Exception as e: - print(f"Error killing FFmpeg process: {e}") + logger.warning(f"Error killing FFmpeg process: {e}") ffmpeg_process = await start_ffmpeg_decoder() pcm_buffer = bytearray() online = online_factory(args, asr, tokenizer) if args.diarization: diarization = DiartDiarization(SAMPLE_RATE) - print("FFmpeg process started.") + logger.info("FFmpeg process started.") await restart_ffmpeg() @@ -153,7 +160,7 @@ async def websocket_endpoint(websocket: WebSocket): timeout=5.0 ) except asyncio.TimeoutError: - print("FFmpeg read timeout. 
Restarting...") + logger.warning("FFmpeg read timeout. Restarting...") await restart_ffmpeg() full_transcription = "" chunk_history = [] @@ -161,17 +168,23 @@ async def websocket_endpoint(websocket: WebSocket): continue # Skip processing and read from new process if not chunk: - print("FFmpeg stdout closed.") + logger.info("FFmpeg stdout closed.") break pcm_buffer.extend(chunk) if len(pcm_buffer) >= BYTES_PER_SEC: + if len(pcm_buffer) > MAX_BYTES_PER_SEC: + logger.warning( + f"""Audio buffer is too large: {len(pcm_buffer) / BYTES_PER_SEC:.2f} seconds. + The model probably struggles to keep up. Consider using a smaller model. + """) # Convert int16 -> float32 pcm_array = ( - np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) + np.frombuffer(pcm_buffer[:MAX_BYTES_PER_SEC], dtype=np.int16).astype(np.float32) / 32768.0 ) - pcm_buffer = bytearray() + pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:] + logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") online.insert_audio_chunk(pcm_array) transcription = online.process_iter() @@ -215,10 +228,10 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_json(response) except Exception as e: - print(f"Exception in ffmpeg_stdout_reader: {e}") + logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") break - print("Exiting ffmpeg_stdout_reader...") + logger.info("Exiting ffmpeg_stdout_reader...") stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader()) @@ -230,12 +243,12 @@ async def websocket_endpoint(websocket: WebSocket): ffmpeg_process.stdin.write(message) ffmpeg_process.stdin.flush() except (BrokenPipeError, AttributeError) as e: - print(f"Error writing to FFmpeg: {e}. Restarting...") + logger.warning(f"Error writing to FFmpeg: {e}. 
Restarting...") await restart_ffmpeg() ffmpeg_process.stdin.write(message) ffmpeg_process.stdin.flush() except WebSocketDisconnect: - print("WebSocket disconnected.") + logger.warning("WebSocket disconnected.") finally: stdout_reader_task.cancel() try: @@ -254,4 +267,4 @@ if __name__ == "__main__": uvicorn.run( "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True, log_level="info" - ) + ) \ No newline at end of file diff --git a/whisper_noserver_test.py b/whisper_noserver_test.py deleted file mode 100644 index c36ec7b..0000000 --- a/whisper_noserver_test.py +++ /dev/null @@ -1,181 +0,0 @@ -#!/usr/bin/env python3 -import sys -import numpy as np -import librosa -from functools import lru_cache -import time -import logging - -logger = logging.getLogger(__name__) - -from src.whisper_streaming.whisper_online import * - -@lru_cache(10**6) -def load_audio(fname): - a, _ = librosa.load(fname, sr=16000, dtype=np.float32) - return a - - -def load_audio_chunk(fname, beg, end): - audio = load_audio(fname) - beg_s = int(beg * 16000) - end_s = int(end * 16000) - return audio[beg_s:end_s] - -if __name__ == "__main__": - - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--audio_path", - type=str, - default='samples_jfk.wav', - help="Filename of 16kHz mono channel wav, on which live streaming is simulated.", - ) - add_shared_args(parser) - parser.add_argument( - "--start_at", - type=float, - default=0.0, - help="Start processing audio at this time.", - ) - parser.add_argument( - "--offline", action="store_true", default=False, help="Offline mode." - ) - parser.add_argument( - "--comp_unaware", - action="store_true", - default=False, - help="Computationally unaware simulation.", - ) - - args = parser.parse_args() - - # reset to store stderr to different file stream, e.g. 
open(os.devnull,"w") - logfile = None # sys.stderr - - if args.offline and args.comp_unaware: - logger.error( - "No or one option from --offline and --comp_unaware are available, not both. Exiting." - ) - sys.exit(1) - - # if args.log_level: - # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s', - # level=getattr(logging, args.log_level)) - - set_logging(args, logger,others=["src.whisper_streaming.online_asr"]) - - audio_path = args.audio_path - - SAMPLING_RATE = 16000 - duration = len(load_audio(audio_path)) / SAMPLING_RATE - logger.info("Audio duration is: %2.2f seconds" % duration) - - asr, online = asr_factory(args, logfile=logfile) - if args.vac: - min_chunk = args.vac_chunk_size - else: - min_chunk = args.min_chunk_size - - # load the audio into the LRU cache before we start the timer - a = load_audio_chunk(audio_path, 0, 1) - - # warm up the ASR because the very first transcribe takes much more time than the other - asr.transcribe(a) - - beg = args.start_at - start = time.time() - beg - - def output_transcript(o, now=None): - # output format in stdout is like: - # 4186.3606 0 1720 Takhle to je - # - the first three words are: - # - emission time from beginning of processing, in milliseconds - # - beg and end timestamp of the text segment, as estimated by Whisper model. 
The timestamps are not accurate, but they're useful anyway - # - the next words: segment transcript - if now is None: - now = time.time() - start - if o[0] is not None: - log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}" - - logger.debug( - log_string - ) - - if logfile is not None: - print( - log_string, - file=logfile, - flush=True, - ) - else: - # No text, so no output - pass - - if args.offline: ## offline mode processing (for testing/debugging) - a = load_audio(audio_path) - online.insert_audio_chunk(a) - try: - o = online.process_iter() - except AssertionError as e: - logger.error(f"assertion error: {repr(e)}") - else: - output_transcript(o) - now = None - elif args.comp_unaware: # computational unaware mode - end = beg + min_chunk - while True: - a = load_audio_chunk(audio_path, beg, end) - online.insert_audio_chunk(a) - try: - o = online.process_iter() - except AssertionError as e: - logger.error(f"assertion error: {repr(e)}") - pass - else: - output_transcript(o, now=end) - - logger.debug(f"## last processed {end:.2f}s") - - if end >= duration: - break - - beg = end - - if end + min_chunk > duration: - end = duration - else: - end += min_chunk - now = duration - - else: # online = simultaneous mode - end = 0 - while True: - now = time.time() - start - if now < end + min_chunk: - time.sleep(min_chunk + end - now) - end = time.time() - start - a = load_audio_chunk(audio_path, beg, end) - beg = end - online.insert_audio_chunk(a) - - try: - o = online.process_iter() - except AssertionError as e: - logger.error(f"assertion error: {e}") - pass - else: - output_transcript(o) - now = time.time() - start - logger.debug( - f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}" - ) - - if end >= duration: - break - now = None - - o = online.finish() - output_transcript(o, now=now)