Merge pull request #47 from QuentinFuxa/logging-and-MAX_BYTES_PER_SEC

Implement logging for WebSocket events and FFmpeg process management.
This commit is contained in:
Quentin Fuxa
2025-02-15 15:22:52 +01:00
committed by GitHub
2 changed files with 25 additions and 193 deletions

View File

@@ -14,8 +14,14 @@ from src.whisper_streaming.whisper_online import backend_factory, online_factory
import subprocess
import math
import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
##### LOAD ARGS #####
parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
@@ -51,6 +57,7 @@ CHANNELS = 1
SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
if args.diarization:
from src.diarization.diarization_online import DiartDiarization
@@ -106,7 +113,7 @@ async def get():
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
print("WebSocket connection opened.")
logger.info("WebSocket connection opened.")
ffmpeg_process = None
pcm_buffer = bytearray()
@@ -120,13 +127,13 @@ async def websocket_endpoint(websocket: WebSocket):
ffmpeg_process.kill()
await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
except Exception as e:
print(f"Error killing FFmpeg process: {e}")
logger.warning(f"Error killing FFmpeg process: {e}")
ffmpeg_process = await start_ffmpeg_decoder()
pcm_buffer = bytearray()
online = online_factory(args, asr, tokenizer)
if args.diarization:
diarization = DiartDiarization(SAMPLE_RATE)
print("FFmpeg process started.")
logger.info("FFmpeg process started.")
await restart_ffmpeg()
@@ -153,7 +160,7 @@ async def websocket_endpoint(websocket: WebSocket):
timeout=5.0
)
except asyncio.TimeoutError:
print("FFmpeg read timeout. Restarting...")
logger.warning("FFmpeg read timeout. Restarting...")
await restart_ffmpeg()
full_transcription = ""
chunk_history = []
@@ -161,17 +168,23 @@ async def websocket_endpoint(websocket: WebSocket):
continue # Skip processing and read from new process
if not chunk:
print("FFmpeg stdout closed.")
logger.info("FFmpeg stdout closed.")
break
pcm_buffer.extend(chunk)
if len(pcm_buffer) >= BYTES_PER_SEC:
if len(pcm_buffer) > MAX_BYTES_PER_SEC:
logger.warning(
f"""Audio buffer is too large: {len(pcm_buffer) / BYTES_PER_SEC:.2f} seconds.
The model probably struggles to keep up. Consider using a smaller model.
""")
# Convert int16 -> float32
pcm_array = (
np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
np.frombuffer(pcm_buffer[:MAX_BYTES_PER_SEC], dtype=np.int16).astype(np.float32)
/ 32768.0
)
pcm_buffer = bytearray()
pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
online.insert_audio_chunk(pcm_array)
transcription = online.process_iter()
@@ -215,10 +228,10 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_json(response)
except Exception as e:
print(f"Exception in ffmpeg_stdout_reader: {e}")
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
break
print("Exiting ffmpeg_stdout_reader...")
logger.info("Exiting ffmpeg_stdout_reader...")
stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
@@ -230,12 +243,12 @@ async def websocket_endpoint(websocket: WebSocket):
ffmpeg_process.stdin.write(message)
ffmpeg_process.stdin.flush()
except (BrokenPipeError, AttributeError) as e:
print(f"Error writing to FFmpeg: {e}. Restarting...")
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
await restart_ffmpeg()
ffmpeg_process.stdin.write(message)
ffmpeg_process.stdin.flush()
except WebSocketDisconnect:
print("WebSocket disconnected.")
logger.warning("WebSocket disconnected.")
finally:
stdout_reader_task.cancel()
try:
@@ -254,4 +267,4 @@ if __name__ == "__main__":
uvicorn.run(
"whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True,
log_level="info"
)
)

View File

@@ -1,181 +0,0 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging
logger = logging.getLogger(__name__)
from src.whisper_streaming.whisper_online import *
@lru_cache(10**6)
def load_audio(fname):
    """Decode *fname* into mono float32 PCM resampled to 16 kHz.

    Results are memoized per path, so repeated chunk reads of the same
    file (see load_audio_chunk) decode the audio only once.
    """
    samples, _rate = librosa.load(fname, sr=16000, dtype=np.float32)
    return samples
def load_audio_chunk(fname, beg, end):
    """Return the [beg, end) slice, in seconds, of the cached 16 kHz audio.

    beg/end are converted to sample indices at the 16 kHz rate that
    load_audio resamples to; the full decode is served from its LRU cache.
    """
    full_audio = load_audio(fname)
    start_idx, stop_idx = int(beg * 16000), int(end * 16000)
    return full_audio[start_idx:stop_idx]
if __name__ == "__main__":
    # CLI entry point: simulates live streaming ASR over a local wav file
    # in one of three modes (offline, computationally-unaware, online).
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_path",
        type=str,
        default='samples_jfk.wav',
        help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
    )
    # Shared Whisper/streaming options (model, chunk sizes, VAC, ...) come
    # from the project helper; their exact set is defined elsewhere.
    add_shared_args(parser)
    parser.add_argument(
        "--start_at",
        type=float,
        default=0.0,
        help="Start processing audio at this time.",
    )
    parser.add_argument(
        "--offline", action="store_true", default=False, help="Offline mode."
    )
    parser.add_argument(
        "--comp_unaware",
        action="store_true",
        default=False,
        help="Computationally unaware simulation.",
    )
    args = parser.parse_args()

    # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
    logfile = None  # sys.stderr

    # --offline and --comp_unaware are mutually exclusive modes.
    if args.offline and args.comp_unaware:
        logger.error(
            "No or one option from --offline and --comp_unaware are available, not both. Exiting."
        )
        sys.exit(1)

    # if args.log_level:
    #     logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
    #                         level=getattr(logging, args.log_level))
    set_logging(args, logger, others=["src.whisper_streaming.online_asr"])

    audio_path = args.audio_path

    SAMPLING_RATE = 16000  # must match the rate load_audio resamples to
    duration = len(load_audio(audio_path)) / SAMPLING_RATE
    logger.info("Audio duration is: %2.2f seconds" % duration)

    # asr_factory builds the backend model and the streaming wrapper;
    # their types are project-defined.
    asr, online = asr_factory(args, logfile=logfile)
    if args.vac:
        min_chunk = args.vac_chunk_size
    else:
        min_chunk = args.min_chunk_size

    # load the audio into the LRU cache before we start the timer
    a = load_audio_chunk(audio_path, 0, 1)

    # warm up the ASR because the very first transcribe takes much more time than the other
    asr.transcribe(a)

    beg = args.start_at
    # 'start' anchors wall-clock time so that (time.time() - start) yields
    # the simulated stream position, offset by --start_at.
    start = time.time() - beg

    def output_transcript(o, now=None):
        # Emit one committed transcript segment; o is (beg_ts, end_ts, text)
        # where a None beg_ts means "nothing committed yet".
        # output format in stdout is like:
        # 4186.3606 0 1720 Takhle to je
        # - the first three words are:
        #    - emission time from beginning of processing, in milliseconds
        #    - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
        # - the next words: segment transcript
        if now is None:
            now = time.time() - start
        if o[0] is not None:
            log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}"
            logger.debug(
                log_string
            )
            if logfile is not None:
                print(
                    log_string,
                    file=logfile,
                    flush=True,
                )
        else:
            # No text, so no output
            pass

    if args.offline:  ## offline mode processing (for testing/debugging)
        # Feed the entire file at once; single process_iter pass.
        a = load_audio(audio_path)
        online.insert_audio_chunk(a)
        try:
            o = online.process_iter()
        except AssertionError as e:
            logger.error(f"assertion error: {repr(e)}")
        else:
            output_transcript(o)
        now = None
    elif args.comp_unaware:  # computational unaware mode
        # Advance simulated time by min_chunk per iteration, ignoring how
        # long the model actually takes; 'now' is the chunk end time.
        end = beg + min_chunk
        while True:
            a = load_audio_chunk(audio_path, beg, end)
            online.insert_audio_chunk(a)
            try:
                o = online.process_iter()
            except AssertionError as e:
                logger.error(f"assertion error: {repr(e)}")
                pass
            else:
                output_transcript(o, now=end)

            logger.debug(f"## last processed {end:.2f}s")

            if end >= duration:
                break

            beg = end
            # Clamp the final chunk to the file duration.
            if end + min_chunk > duration:
                end = duration
            else:
                end += min_chunk
        now = duration
    else:  # online = simultaneous mode
        # Real-time simulation: sleep until at least min_chunk of new audio
        # "arrived" (wall clock), then feed everything up to 'end'.
        end = 0
        while True:
            now = time.time() - start
            if now < end + min_chunk:
                time.sleep(min_chunk + end - now)
            end = time.time() - start
            a = load_audio_chunk(audio_path, beg, end)
            beg = end
            online.insert_audio_chunk(a)

            try:
                o = online.process_iter()
            except AssertionError as e:
                logger.error(f"assertion error: {e}")
                pass
            else:
                output_transcript(o)
            now = time.time() - start
            logger.debug(
                f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}"
            )

            if end >= duration:
                break
        now = None

    # Flush whatever the streaming policy still holds as the final segment.
    o = online.finish()
    output_transcript(o, now=now)