From dc02bcdbdd91afe36fa9d1fce0ee17df5cc5d1a4 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Tue, 18 Mar 2025 18:31:23 +0100 Subject: [PATCH 1/4] refacto 0 --- audio.py | 306 +++++++++++++++ diarization/diarization_online.py | 6 +- formatters.py | 91 +++++ parse_args.py | 52 +++ state.py | 96 +++++ whisper_fastapi_online_server.py | 431 ++------------------- whisper_streaming_custom/whisper_online.py | 2 +- 7 files changed, 576 insertions(+), 408 deletions(-) create mode 100644 audio.py create mode 100644 formatters.py create mode 100644 parse_args.py create mode 100644 state.py diff --git a/audio.py b/audio.py new file mode 100644 index 0000000..1b97746 --- /dev/null +++ b/audio.py @@ -0,0 +1,306 @@ +import io +import argparse +import asyncio +import numpy as np +import ffmpeg +from time import time, sleep +from contextlib import asynccontextmanager + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse +from fastapi.middleware.cors import CORSMiddleware + +from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr +from timed_objects import ASRToken + +import math +import logging +from datetime import timedelta +import traceback +from state import SharedState +from formatters import format_time +from parse_args import parse_args + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logging.getLogger().setLevel(logging.WARNING) +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class AudioProcessor: + + def __init__(self, args, asr, tokenizer): + self.args = args + self.sample_rate = 16000 + self.channels = 1 + self.samples_per_sec = int(self.sample_rate * args.min_chunk_size) + self.bytes_per_sample = 2 + self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample + self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz + self.shared_state = SharedState() + self.asr = asr + self.tokenizer = tokenizer + + def convert_pcm_to_float(self, pcm_buffer): + """ + Converts a PCM buffer in s16le format to a normalized NumPy array. + Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format + Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0 + """ + pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) + / 32768.0) + return pcm_array + + async def start_ffmpeg_decoder(self): + """ + Start an FFmpeg process in async streaming mode that reads WebM from stdin + and outputs raw s16le PCM on stdout. Returns the process object. + """ + process = ( + ffmpeg.input("pipe:0", format="webm") + .output( + "pipe:1", + format="s16le", + acodec="pcm_s16le", + ac=self.channels, + ar=str(self.sample_rate), + ) + .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True) + ) + return process + + async def restart_ffmpeg(self, ffmpeg_process, online, pcm_buffer): + if ffmpeg_process: + try: + ffmpeg_process.kill() + await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait) + except Exception as e: + logger.warning(f"Error killing FFmpeg process: {e}") + ffmpeg_process = await self.start_ffmpeg_decoder() + pcm_buffer = bytearray() + + if self.args.transcription: + online = online_factory(self.args, self.asr, self.tokenizer) + + await self.shared_state.reset() + logger.info("FFmpeg process started.") + return ffmpeg_process, online, pcm_buffer + + + + async def ffmpeg_stdout_reader(self, ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue): + loop = asyncio.get_event_loop() + beg = time() + + while True: + try: + elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec + ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096) + beg = time() + + # Read chunk with timeout + try: + chunk = await asyncio.wait_for( + loop.run_in_executor( + None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration + ), + timeout=15.0 + ) + except asyncio.TimeoutError: + logger.warning("FFmpeg read timeout. Restarting...") + ffmpeg_process, online, pcm_buffer = await self.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) + beg = time() + continue # Skip processing and read from new process + + if not chunk: + logger.info("FFmpeg stdout closed.") + break + pcm_buffer.extend(chunk) + + if self.args.diarization and diarization_queue: + await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy()) + + if len(pcm_buffer) >= self.bytes_per_sec: + if len(pcm_buffer) > self.max_bytes_per_sec: + logger.warning( + f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds. + The model probably struggles to keep up. Consider using a smaller model. + """) + + pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec]) + pcm_buffer = pcm_buffer[self.max_bytes_per_sec:] + + if self.args.transcription and transcription_queue: + await transcription_queue.put(pcm_array.copy()) + + + if not self.args.transcription and not self.args.diarization: + await asyncio.sleep(0.1) + + except Exception as e: + logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + break + logger.info("Exiting ffmpeg_stdout_reader...") + + + + + async def transcription_processor(self, pcm_queue, online): + full_transcription = "" + sep = online.asr.sep + + while True: + try: + pcm_array = await pcm_queue.get() + + logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") + + # Process transcription + online.insert_audio_chunk(pcm_array) + new_tokens = online.process_iter() + + if new_tokens: + full_transcription += sep.join([t.text for t in new_tokens]) + + _buffer = online.get_buffer() + buffer = _buffer.text + end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0) + + if buffer in full_transcription: + buffer = "" + + await self.shared_state.update_transcription( + new_tokens, buffer, end_buffer, full_transcription, sep) + + except Exception as e: + logger.warning(f"Exception in transcription_processor: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + finally: + pcm_queue.task_done() + + async def diarization_processor(self, pcm_queue, diarization_obj): + buffer_diarization = "" + + while True: + try: + pcm_array = await pcm_queue.get() + + # Process diarization + await diarization_obj.diarize(pcm_array) + + # Get current state + state = await self.shared_state.get_current_state() + tokens = state["tokens"] + end_attributed_speaker = state["end_attributed_speaker"] + + # Update speaker information + new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens( + end_attributed_speaker, tokens) + + await self.shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization) + + except Exception as e: + logger.warning(f"Exception in diarization_processor: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + finally: + pcm_queue.task_done() + + async def results_formatter(self, websocket): + while True: + try: + # Get the current state + state = await self.shared_state.get_current_state() + tokens = state["tokens"] + buffer_transcription = state["buffer_transcription"] + buffer_diarization = state["buffer_diarization"] + end_attributed_speaker = state["end_attributed_speaker"] + remaining_time_transcription = state["remaining_time_transcription"] + remaining_time_diarization = state["remaining_time_diarization"] + sep = state["sep"] + + # If diarization is enabled but no transcription, add dummy tokens periodically + if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization: + await self.shared_state.add_dummy_token() + sleep(0.5) + state = await self.shared_state.get_current_state() + tokens = state["tokens"] + # Process tokens to create response + previous_speaker = -1 + lines = [] + last_end_diarized = 0 + undiarized_text = [] + + for token in tokens: + speaker = token.speaker + if self.args.diarization: + if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: + undiarized_text.append(token.text) + continue + elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: + speaker = previous_speaker + if speaker not in [-1, 0]: + last_end_diarized = max(token.end, last_end_diarized) + + if speaker != previous_speaker or not lines: + lines.append( + { + "speaker": speaker, + "text": token.text, + "beg": format_time(token.start), + "end": format_time(token.end), + "diff": round(token.end - last_end_diarized, 2) + } + ) + previous_speaker = speaker + elif token.text: # Only append if text isn't empty + lines[-1]["text"] += sep + token.text + lines[-1]["end"] = format_time(token.end) + lines[-1]["diff"] = round(token.end - last_end_diarized, 2) + + if undiarized_text: + combined_buffer_diarization = sep.join(undiarized_text) + if buffer_transcription: + combined_buffer_diarization += sep + await self.shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization) + buffer_diarization = combined_buffer_diarization + + if lines: + response = { + "lines": lines, + "buffer_transcription": buffer_transcription, + "buffer_diarization": buffer_diarization, + "remaining_time_transcription": remaining_time_transcription, + "remaining_time_diarization": remaining_time_diarization + } + else: + response = { + "lines": [{ + "speaker": 1, + "text": "", + "beg": format_time(0), + "end": format_time(tokens[-1].end) if tokens else format_time(0), + "diff": 0 + }], + "buffer_transcription": buffer_transcription, + "buffer_diarization": buffer_diarization, + "remaining_time_transcription": remaining_time_transcription, + "remaining_time_diarization": remaining_time_diarization + + } + + response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization + + if response_content != self.shared_state.last_response_content: + if lines or buffer_transcription or buffer_diarization: + await websocket.send_json(response) + self.shared_state.last_response_content = response_content + + # Add a small delay to avoid overwhelming the client + await asyncio.sleep(0.1) + + except Exception as e: + logger.warning(f"Exception in results_formatter: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + await asyncio.sleep(0.5) # Back off on error diff --git a/diarization/diarization_online.py b/diarization/diarization_online.py index d04c3a2..45bec13 100644 --- a/diarization/diarization_online.py +++ b/diarization/diarization_online.py @@ -5,7 +5,7 @@ import numpy as np import logging -from diart import SpeakerDiarization +from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.inference import StreamingInference from diart.sources import AudioSource from timed_objects import SpeakerSegment @@ -103,8 +103,8 @@ class WebSocketAudioSource(AudioSource): class DiartDiarization: - def __init__(self, sample_rate: int, use_microphone: bool = False): - self.pipeline = SpeakerDiarization() + def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False): + self.pipeline = SpeakerDiarization(config=config) self.observer = DiarizationObserver() if use_microphone: diff --git a/formatters.py b/formatters.py new file mode 100644 index 0000000..d48473f --- /dev/null +++ b/formatters.py @@ -0,0 +1,91 @@ +from typing import Dict, Any, List +from datetime import timedelta + +def format_time(seconds: float) -> str: + """Format seconds as HH:MM:SS.""" + return str(timedelta(seconds=int(seconds))) + +def format_response(state: Dict[str, Any], with_diarization: bool = False) -> Dict[str, Any]: + """ + Format the shared state into a client-friendly response. + + Args: + state: Current shared state dictionary + with_diarization: Whether to include diarization formatting + + Returns: + Formatted response dictionary ready to send to client + """ + tokens = state["tokens"] + buffer_transcription = state["buffer_transcription"] + buffer_diarization = state["buffer_diarization"] + end_attributed_speaker = state["end_attributed_speaker"] + remaining_time_transcription = state["remaining_time_transcription"] + remaining_time_diarization = state["remaining_time_diarization"] + sep = state["sep"] + + # Default response for empty state + if not tokens: + return { + "lines": [{ + "speaker": 1, + "text": "", + "beg": format_time(0), + "end": format_time(0), + "diff": 0 + }], + "buffer_transcription": buffer_transcription, + "buffer_diarization": buffer_diarization, + "remaining_time_transcription": remaining_time_transcription, + "remaining_time_diarization": remaining_time_diarization + } + + # Process tokens to create response + previous_speaker = -1 + lines = [] + last_end_diarized = 0 + undiarized_text = [] + + for token in tokens: + speaker = token.speaker + + # Handle diarization logic + if with_diarization: + if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: + undiarized_text.append(token.text) + continue + elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: + speaker = previous_speaker + + if speaker not in [-1, 0]: + last_end_diarized = max(token.end, last_end_diarized) + + # Add new line or append to existing line + if speaker != previous_speaker or not lines: + lines.append({ + "speaker": speaker, + "text": token.text, + "beg": format_time(token.start), + "end": format_time(token.end), + "diff": round(token.end - last_end_diarized, 2) + }) + previous_speaker = speaker + elif token.text: # Only append if text isn't empty + lines[-1]["text"] += sep + token.text + lines[-1]["end"] = format_time(token.end) + lines[-1]["diff"] = round(token.end - last_end_diarized, 2) + + # If we have undiarized text, include it in the buffer + if undiarized_text: + combined_buffer = sep.join(undiarized_text) + if buffer_transcription: + combined_buffer += sep + buffer_transcription + buffer_diarization = combined_buffer + + return { + "lines": lines, + "buffer_transcription": buffer_transcription, + "buffer_diarization": buffer_diarization, + "remaining_time_transcription": remaining_time_transcription, + "remaining_time_diarization": remaining_time_diarization + } \ No newline at end of file diff --git a/parse_args.py b/parse_args.py new file mode 100644 index 0000000..f201477 --- /dev/null +++ b/parse_args.py @@ -0,0 +1,52 @@ + +import argparse +from whisper_streaming_custom.whisper_online import add_shared_args + + +def parse_args(): + parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server") + parser.add_argument( + "--host", + type=str, + default="localhost", + help="The host address to bind the server to.", + ) + parser.add_argument( + "--port", type=int, default=8000, help="The port number to bind the server to." + ) + parser.add_argument( + "--warmup-file", + type=str, + default=None, + dest="warmup_file", + help=""" + The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. + If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav. + If False, no warmup is performed. + """, + ) + + parser.add_argument( + "--confidence-validation", + type=bool, + default=False, + help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.", + ) + + parser.add_argument( + "--diarization", + type=bool, + default=True, + help="Whether to enable speaker diarization.", + ) + + parser.add_argument( + "--transcription", + type=bool, + default=True, + help="To disable to only see live diarization results.", + ) + + add_shared_args(parser) + args = parser.parse_args() + return args \ No newline at end of file diff --git a/state.py b/state.py new file mode 100644 index 0000000..a4c864a --- /dev/null +++ b/state.py @@ -0,0 +1,96 @@ +import asyncio +import logging +from time import time +from typing import List, Dict, Any, Optional +from dataclasses import dataclass, field +from timed_objects import ASRToken + +logger = logging.getLogger(__name__) + + +class SharedState: + """ + Thread-safe state manager for streaming transcription and diarization. + Handles coordination between audio processing, transcription, and diarization. + """ + + def __init__(self): + self.tokens: List[ASRToken] = [] + self.buffer_transcription: str = "" + self.buffer_diarization: str = "" + self.full_transcription: str = "" + self.end_buffer: float = 0 + self.end_attributed_speaker: float = 0 + self.lock = asyncio.Lock() + self.beg_loop: float = time() + self.sep: str = " " # Default separator + self.last_response_content: str = "" # To track changes in response + + async def update_transcription(self, new_tokens: List[ASRToken], buffer: str, + end_buffer: float, full_transcription: str, sep: str) -> None: + """Update the state with new transcription data.""" + async with self.lock: + self.tokens.extend(new_tokens) + self.buffer_transcription = buffer + self.end_buffer = end_buffer + self.full_transcription = full_transcription + self.sep = sep + + async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None: + """Update the state with new diarization data.""" + async with self.lock: + self.end_attributed_speaker = end_attributed_speaker + if buffer_diarization: + self.buffer_diarization = buffer_diarization + + async def add_dummy_token(self) -> None: + """Add a dummy token to keep the state updated even without transcription.""" + async with self.lock: + current_time = time() - self.beg_loop + dummy_token = ASRToken( + start=current_time, + end=current_time + 1, + text=".", + speaker=-1, + is_dummy=True + ) + self.tokens.append(dummy_token) + + async def get_current_state(self) -> Dict[str, Any]: + """Get the current state with calculated timing information.""" + async with self.lock: + current_time = time() + remaining_time_transcription = 0 + remaining_time_diarization = 0 + + # Calculate remaining time for transcription buffer + if self.end_buffer > 0: + remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) + + # Calculate remaining time for diarization + if self.tokens: + latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) + remaining_time_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2)) + + return { + "tokens": self.tokens.copy(), + "buffer_transcription": self.buffer_transcription, + "buffer_diarization": self.buffer_diarization, + "end_buffer": self.end_buffer, + "end_attributed_speaker": self.end_attributed_speaker, + "sep": self.sep, + "remaining_time_transcription": remaining_time_transcription, + "remaining_time_diarization": remaining_time_diarization + } + + async def reset(self) -> None: + """Reset the state to initial values.""" + async with self.lock: + self.tokens = [] + self.buffer_transcription = "" + self.buffer_diarization = "" + self.end_buffer = 0 + self.end_attributed_speaker = 0 + self.full_transcription = "" + self.beg_loop = time() + self.last_response_content = "" \ No newline at end of file diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 24597bc..99eb4d4 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -17,147 +17,28 @@ import math import logging from datetime import timedelta import traceback - -def format_time(seconds): - return str(timedelta(seconds=int(seconds))) - +from state import SharedState +from formatters import format_time +from parse_args import parse_args +from audio import AudioProcessor logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -##### LOAD ARGS ##### -parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server") -parser.add_argument( - "--host", - type=str, - default="localhost", - help="The host address to bind the server to.", -) -parser.add_argument( - "--port", type=int, default=8000, help="The port number to bind the server to." -) -parser.add_argument( - "--warmup-file", - type=str, - default=None, - dest="warmup_file", - help=""" - The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. - If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav. - If False, no warmup is performed. - """, -) -parser.add_argument( - "--confidence-validation", - type=bool, - default=False, - help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.", -) - -parser.add_argument( - "--diarization", - type=bool, - default=False, - help="Whether to enable speaker diarization.", -) - -parser.add_argument( - "--transcription", - type=bool, - default=True, - help="To disable to only see live diarization results.", -) - -add_shared_args(parser) -args = parser.parse_args() +args = parse_args() SAMPLE_RATE = 16000 -CHANNELS = 1 -SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size) -BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample -BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE -MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz +# CHANNELS = 1 +# SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size) +# BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample +# BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE +# MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz -class SharedState: - def __init__(self): - self.tokens = [] - self.buffer_transcription = "" - self.buffer_diarization = "" - self.full_transcription = "" - self.end_buffer = 0 - self.end_attributed_speaker = 0 - self.lock = asyncio.Lock() - self.beg_loop = time() - self.sep = " " # Default separator - self.last_response_content = "" # To track changes in response - - async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep): - async with self.lock: - self.tokens.extend(new_tokens) - self.buffer_transcription = buffer - self.end_buffer = end_buffer - self.full_transcription = full_transcription - self.sep = sep - - async def update_diarization(self, end_attributed_speaker, buffer_diarization=""): - async with self.lock: - self.end_attributed_speaker = end_attributed_speaker - if buffer_diarization: - self.buffer_diarization = buffer_diarization - - async def add_dummy_token(self): - async with self.lock: - current_time = time() - self.beg_loop - dummy_token = ASRToken( - start=current_time, - end=current_time + 1, - text=".", - speaker=-1, - is_dummy=True - ) - self.tokens.append(dummy_token) - - async def get_current_state(self): - async with self.lock: - current_time = time() - remaining_time_transcription = 0 - remaining_time_diarization = 0 - - # Calculate remaining time for transcription buffer - if self.end_buffer > 0: - remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) - - # Calculate remaining time for diarization - remaining_time_diarization = max(0, round(max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) - self.end_attributed_speaker, 2)) - - return { - "tokens": self.tokens.copy(), - "buffer_transcription": self.buffer_transcription, - "buffer_diarization": self.buffer_diarization, - "end_buffer": self.end_buffer, - "end_attributed_speaker": self.end_attributed_speaker, - "sep": self.sep, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - async def reset(self): - """Reset the state.""" - async with self.lock: - self.tokens = [] - self.buffer_transcription = "" - self.buffer_diarization = "" - self.end_buffer = 0 - self.end_attributed_speaker = 0 - self.full_transcription = "" - self.beg_loop = time() - self.last_response_content = "" - ##### LOAD APP ##### @asynccontextmanager @@ -190,300 +71,45 @@ app.add_middleware( with open("web/live_transcription.html", "r", encoding="utf-8") as f: html = f.read() -def convert_pcm_to_float(pcm_buffer): - """ - Converts a PCM buffer in s16le format to a normalized NumPy array. - Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format - Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0 - """ - pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) - / 32768.0) - return pcm_array - -async def start_ffmpeg_decoder(): - """ - Start an FFmpeg process in async streaming mode that reads WebM from stdin - and outputs raw s16le PCM on stdout. Returns the process object. - """ - process = ( - ffmpeg.input("pipe:0", format="webm") - .output( - "pipe:1", - format="s16le", - acodec="pcm_s16le", - ac=CHANNELS, - ar=str(SAMPLE_RATE), - ) - .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True) - ) - return process - -async def transcription_processor(shared_state, pcm_queue, online): - full_transcription = "" - sep = online.asr.sep - - while True: - try: - pcm_array = await pcm_queue.get() - - logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") - - # Process transcription - online.insert_audio_chunk(pcm_array) - new_tokens = online.process_iter() - - if new_tokens: - full_transcription += sep.join([t.text for t in new_tokens]) - - _buffer = online.get_buffer() - buffer = _buffer.text - end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0) - - if buffer in full_transcription: - buffer = "" - - await shared_state.update_transcription( - new_tokens, buffer, end_buffer, full_transcription, sep) - - except Exception as e: - logger.warning(f"Exception in transcription_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - pcm_queue.task_done() - -async def diarization_processor(shared_state, pcm_queue, diarization_obj): - buffer_diarization = "" - - while True: - try: - pcm_array = await pcm_queue.get() - - # Process diarization - await diarization_obj.diarize(pcm_array) - - # Get current state - state = await shared_state.get_current_state() - tokens = state["tokens"] - end_attributed_speaker = state["end_attributed_speaker"] - - # Update speaker information - new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens( - end_attributed_speaker, tokens) - - await shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization) - - except Exception as e: - logger.warning(f"Exception in diarization_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - pcm_queue.task_done() - -async def results_formatter(shared_state, websocket): - while True: - try: - # Get the current state - state = await shared_state.get_current_state() - tokens = state["tokens"] - buffer_transcription = state["buffer_transcription"] - buffer_diarization = state["buffer_diarization"] - end_attributed_speaker = state["end_attributed_speaker"] - remaining_time_transcription = state["remaining_time_transcription"] - remaining_time_diarization = state["remaining_time_diarization"] - sep = state["sep"] - - # If diarization is enabled but no transcription, add dummy tokens periodically - if (not tokens or tokens[-1].is_dummy) and not args.transcription and args.diarization: - await shared_state.add_dummy_token() - sleep(0.5) - state = await shared_state.get_current_state() - tokens = state["tokens"] - # Process tokens to create response - previous_speaker = -1 - lines = [] - last_end_diarized = 0 - undiarized_text = [] - - for token in tokens: - speaker = token.speaker - if args.diarization: - if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: - undiarized_text.append(token.text) - continue - elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: - speaker = previous_speaker - if speaker not in [-1, 0]: - last_end_diarized = max(token.end, last_end_diarized) - - if speaker != previous_speaker or not lines: - lines.append( - { - "speaker": speaker, - "text": token.text, - "beg": format_time(token.start), - "end": format_time(token.end), - "diff": round(token.end - last_end_diarized, 2) - } - ) - previous_speaker = speaker - elif token.text: # Only append if text isn't empty - lines[-1]["text"] += sep + token.text - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) - - if undiarized_text: - combined_buffer_diarization = sep.join(undiarized_text) - if buffer_transcription: - combined_buffer_diarization += sep - await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization) - buffer_diarization = combined_buffer_diarization - - if lines: - response = { - "lines": lines, - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - else: - response = { - "lines": [{ - "speaker": 1, - "text": "", - "beg": format_time(0), - "end": format_time(tokens[-1].end) if tokens else format_time(0), - "diff": 0 - }], - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - - } - - response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization - - if response_content != shared_state.last_response_content: - if lines or buffer_transcription or buffer_diarization: - await websocket.send_json(response) - shared_state.last_response_content = response_content - - # Add a small delay to avoid overwhelming the client - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in results_formatter: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - await asyncio.sleep(0.5) # Back off on error - -##### ENDPOINTS ##### - @app.get("/") async def get(): return HTMLResponse(html) + + + + + + + @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket): + audio_processor = AudioProcessor(args, asr, tokenizer) + await websocket.accept() logger.info("WebSocket connection opened.") ffmpeg_process = None pcm_buffer = bytearray() - shared_state = SharedState() transcription_queue = asyncio.Queue() if args.transcription else None diarization_queue = asyncio.Queue() if args.diarization else None online = None - async def restart_ffmpeg(): - nonlocal ffmpeg_process, online, pcm_buffer - if ffmpeg_process: - try: - ffmpeg_process.kill() - await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait) - except Exception as e: - logger.warning(f"Error killing FFmpeg process: {e}") - ffmpeg_process = await start_ffmpeg_decoder() - pcm_buffer = bytearray() - - if args.transcription: - online = online_factory(args, asr, tokenizer) - - await shared_state.reset() - logger.info("FFmpeg process started.") - - await restart_ffmpeg() - + ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) tasks = [] if args.transcription and online: tasks.append(asyncio.create_task( - transcription_processor(shared_state, transcription_queue, online))) + audio_processor.transcription_processor(transcription_queue, online))) if args.diarization and diarization: tasks.append(asyncio.create_task( - diarization_processor(shared_state, diarization_queue, diarization))) - formatter_task = asyncio.create_task(results_formatter(shared_state, websocket)) + audio_processor.diarization_processor(diarization_queue, diarization))) + formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket)) tasks.append(formatter_task) - - async def ffmpeg_stdout_reader(): - nonlocal ffmpeg_process, pcm_buffer - loop = asyncio.get_event_loop() - beg = time() - - while True: - try: - elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec - ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096) - beg = time() - - # Read chunk with timeout - try: - chunk = await asyncio.wait_for( - loop.run_in_executor( - None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration - ), - timeout=15.0 - ) - except asyncio.TimeoutError: - logger.warning("FFmpeg read timeout. Restarting...") - await restart_ffmpeg() - beg = time() - continue # Skip processing and read from new process - - if not chunk: - logger.info("FFmpeg stdout closed.") - break - pcm_buffer.extend(chunk) - - if args.diarization and diarization_queue: - await diarization_queue.put(convert_pcm_to_float(pcm_buffer).copy()) - - if len(pcm_buffer) >= BYTES_PER_SEC: - if len(pcm_buffer) > MAX_BYTES_PER_SEC: - logger.warning( - f"""Audio buffer is too large: {len(pcm_buffer) / BYTES_PER_SEC:.2f} seconds. - The model probably struggles to keep up. Consider using a smaller model. - """) - - pcm_array = convert_pcm_to_float(pcm_buffer[:MAX_BYTES_PER_SEC]) - pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:] - - if args.transcription and transcription_queue: - await transcription_queue.put(pcm_array.copy()) - - - if not args.transcription and not args.diarization: - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - break - - logger.info("Exiting ffmpeg_stdout_reader...") - - stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader()) - tasks.append(stdout_reader_task) + stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue)) + tasks.append(stdout_reader_task) + try: while True: # Receive incoming WebM audio chunks from the client @@ -493,7 +119,7 @@ async def websocket_endpoint(websocket: WebSocket): ffmpeg_process.stdin.flush() except (BrokenPipeError, AttributeError) as e: logger.warning(f"Error writing to FFmpeg: {e}. Restarting...") - await restart_ffmpeg() + ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) ffmpeg_process.stdin.write(message) ffmpeg_process.stdin.flush() except WebSocketDisconnect: @@ -501,17 +127,14 @@ async def websocket_endpoint(websocket: WebSocket): finally: for task in tasks: task.cancel() - try: await asyncio.gather(*tasks, return_exceptions=True) ffmpeg_process.stdin.close() ffmpeg_process.wait() except Exception as e: logger.warning(f"Error during cleanup: {e}") - if args.diarization and diarization: diarization.close() - logger.info("WebSocket endpoint cleaned up.") if __name__ == "__main__": diff --git a/whisper_streaming_custom/whisper_online.py b/whisper_streaming_custom/whisper_online.py index 617f05b..d7263ac 100644 --- a/whisper_streaming_custom/whisper_online.py +++ b/whisper_streaming_custom/whisper_online.py @@ -71,7 +71,7 @@ def add_shared_args(parser): parser.add_argument( "--min-chunk-size", type=float, - default=1.0, + default=0.5, help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.", ) parser.add_argument( From 5ca65e21b725e102958deb6ded37f6bb04b928d7 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 19 Mar 2025 10:33:22 +0100 Subject: [PATCH 2/4] Refactor DiartDiarization initialization and streamline WebSocket audio processing --- audio.py | 136 ++++++++++++++++++------------ diarization/diarization_online.py | 2 +- whisper_fastapi_online_server.py | 80 ++---------------- 3 files changed, 89 insertions(+), 129 deletions(-) diff --git a/audio.py b/audio.py index 1b97746..ee6ca56 100644 --- a/audio.py +++ b/audio.py @@ -1,25 +1,15 @@ -import io -import argparse import asyncio import numpy as np import ffmpeg from time import time, sleep -from contextlib import asynccontextmanager -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from fastapi.responses import HTMLResponse -from fastapi.middleware.cors import CORSMiddleware - -from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr -from timed_objects import ASRToken +from whisper_streaming_custom.whisper_online import online_factory import math import logging -from datetime import timedelta import traceback from state import SharedState from formatters import format_time -from parse_args import parse_args logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -27,7 +17,6 @@ logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - class AudioProcessor: def __init__(self, args, asr, tokenizer): @@ -38,9 +27,22 @@ class AudioProcessor: self.bytes_per_sample = 2 self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz + + self.shared_state = SharedState() self.asr = asr self.tokenizer = tokenizer + + self.ffmpeg_process = self.start_ffmpeg_decoder() + + self.transcription_queue = asyncio.Queue() if self.args.transcription else None + self.diarization_queue = asyncio.Queue() if self.args.diarization else None + + self.pcm_buffer = bytearray() + if self.args.transcription: + self.online = online_factory(self.args, self.asr, self.tokenizer) + + def convert_pcm_to_float(self, pcm_buffer): """ @@ -70,26 +72,17 @@ class AudioProcessor: ) return process - async def restart_ffmpeg(self, ffmpeg_process, online, pcm_buffer): - if ffmpeg_process: + async def restart_ffmpeg(self): + if self.ffmpeg_process: try: - ffmpeg_process.kill() - await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait) + self.ffmpeg_process.kill() + await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) except Exception as e: logger.warning(f"Error killing FFmpeg process: {e}") - ffmpeg_process = await self.start_ffmpeg_decoder() - pcm_buffer = bytearray() - - if self.args.transcription: - online = online_factory(self.args, self.asr, self.tokenizer) - - await self.shared_state.reset() - logger.info("FFmpeg process started.") - return ffmpeg_process, online, pcm_buffer + self.ffmpeg_process = await self.start_ffmpeg_decoder() + self.pcm_buffer = bytearray() - - - async def ffmpeg_stdout_reader(self, ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue): + async def ffmpeg_stdout_reader(self): loop = asyncio.get_event_loop() beg = time() @@ -103,36 +96,36 @@ class AudioProcessor: try: chunk = await asyncio.wait_for( loop.run_in_executor( - None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration + None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration ), timeout=15.0 ) except asyncio.TimeoutError: logger.warning("FFmpeg read timeout. Restarting...") - ffmpeg_process, online, pcm_buffer = await self.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) + await self.restart_ffmpeg() beg = time() continue # Skip processing and read from new process if not chunk: logger.info("FFmpeg stdout closed.") break - pcm_buffer.extend(chunk) + self.pcm_buffer.extend(chunk) - if self.args.diarization and diarization_queue: - await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy()) + if self.args.diarization and self.diarization_queue: + await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy()) - if len(pcm_buffer) >= self.bytes_per_sec: - if len(pcm_buffer) > self.max_bytes_per_sec: + if len(self.pcm_buffer) >= self.bytes_per_sec: + if len(self.pcm_buffer) > self.max_bytes_per_sec: logger.warning( - f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds. + f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds. The model probably struggles to keep up. Consider using a smaller model. """) - pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec]) - pcm_buffer = pcm_buffer[self.max_bytes_per_sec:] + pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec]) + self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:] - if self.args.transcription and transcription_queue: - await transcription_queue.put(pcm_array.copy()) + if self.args.transcription and self.transcription_queue: + await self.transcription_queue.put(pcm_array.copy()) if not self.args.transcription and not self.args.diarization: @@ -144,27 +137,24 @@ class AudioProcessor: break logger.info("Exiting ffmpeg_stdout_reader...") - - - - async def transcription_processor(self, pcm_queue, online): + async def transcription_processor(self): full_transcription = "" - sep = online.asr.sep + sep = self.online.asr.sep while True: try: - pcm_array = await pcm_queue.get() + pcm_array = await self.transcription_queue.get() - logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") + logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.") # Process transcription - online.insert_audio_chunk(pcm_array) - new_tokens = online.process_iter() + self.online.insert_audio_chunk(pcm_array) + new_tokens = self.online.process_iter() if new_tokens: full_transcription += sep.join([t.text for t in new_tokens]) - _buffer = online.get_buffer() + _buffer = self.online.get_buffer() buffer = _buffer.text end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0) @@ -178,14 +168,15 @@ class AudioProcessor: logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") finally: - pcm_queue.task_done() + self.transcription_queue.task_done() - async def diarization_processor(self, pcm_queue, diarization_obj): + + async def diarization_processor(self, diarization_obj): buffer_diarization = "" while True: try: - pcm_array = await pcm_queue.get() + pcm_array = await self.diarization_queue.get() # Process diarization await diarization_obj.diarize(pcm_array) @@ -205,7 +196,7 @@ class AudioProcessor: logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") finally: - pcm_queue.task_done() + self.diarization_queue.task_done() async def results_formatter(self, websocket): while True: @@ -304,3 +295,40 @@ class AudioProcessor: logger.warning(f"Exception in results_formatter: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") await asyncio.sleep(0.5) # Back off on error + + async def create_tasks(self, websocket, diarization): + tasks = [] + if self.args.transcription and self.online: + tasks.append(asyncio.create_task(self.transcription_processor())) + if self.args.diarization and diarization: + tasks.append(asyncio.create_task(self.diarization_processor(diarization))) + formatter_task = asyncio.create_task(self.results_formatter(websocket)) + tasks.append(formatter_task) + stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) + tasks.append(stdout_reader_task) + self.tasks = tasks + self.diarization = diarization + + async def cleanup(self): + for task in self.tasks: + task.cancel() + try: + await asyncio.gather(*self.tasks, return_exceptions=True) + self.ffmpeg_process.stdin.close() + self.ffmpeg_process.wait() + except Exception as e: + logger.warning(f"Error during cleanup: {e}") + if self.args.diarization and self.diarization: + self.diarization.close() + + async def process_audio(self, message): + try: + self.ffmpeg_process.stdin.write(message) + self.ffmpeg_process.stdin.flush() + except (BrokenPipeError, AttributeError) as e: + logger.warning(f"Error writing to FFmpeg: {e}. Restarting...") + await self.restart_ffmpeg() + self.ffmpeg_process.stdin.write(message) + self.ffmpeg_process.stdin.flush() + + \ No newline at end of file diff --git a/diarization/diarization_online.py b/diarization/diarization_online.py index 45bec13..622fb15 100644 --- a/diarization/diarization_online.py +++ b/diarization/diarization_online.py @@ -103,7 +103,7 @@ class WebSocketAudioSource(AudioSource): class DiartDiarization: - def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False): + def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False): self.pipeline = SpeakerDiarization(config=config) self.observer = DiarizationObserver() diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 99eb4d4..b0ca658 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -1,24 +1,11 @@ -import io -import argparse -import asyncio -import numpy as np -import ffmpeg -from time import time, sleep from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware -from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr -from timed_objects import ASRToken - -import math +from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr import logging -from datetime import timedelta -import traceback -from state import SharedState -from formatters import format_time from parse_args import parse_args from audio import AudioProcessor @@ -27,19 +14,8 @@ logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - - args = parse_args() -SAMPLE_RATE = 16000 -# CHANNELS = 1 -# SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size) -# BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample -# BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE -# MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz - - -##### LOAD APP ##### @asynccontextmanager async def lifespan(app: FastAPI): @@ -52,7 +28,7 @@ async def lifespan(app: FastAPI): if args.diarization: from diarization.diarization_online import DiartDiarization - diarization = DiartDiarization(SAMPLE_RATE) + diarization = DiartDiarization() else : diarization = None yield @@ -75,66 +51,22 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f: async def get(): return HTMLResponse(html) - - - - - - - @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket): audio_processor = AudioProcessor(args, asr, tokenizer) await websocket.accept() logger.info("WebSocket connection opened.") - - ffmpeg_process = None - pcm_buffer = bytearray() - - transcription_queue = asyncio.Queue() if args.transcription else None - diarization_queue = asyncio.Queue() if args.diarization else None - - online = None - - ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) - tasks = [] - if args.transcription and online: - tasks.append(asyncio.create_task( - audio_processor.transcription_processor(transcription_queue, online))) - if args.diarization and diarization: - tasks.append(asyncio.create_task( - audio_processor.diarization_processor(diarization_queue, diarization))) - formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket)) - tasks.append(formatter_task) - stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue)) - tasks.append(stdout_reader_task) - + + await audio_processor.create_tasks(websocket, diarization) try: while True: - # Receive incoming WebM audio chunks from the client message = await websocket.receive_bytes() - try: - ffmpeg_process.stdin.write(message) - ffmpeg_process.stdin.flush() - except (BrokenPipeError, AttributeError) as e: - logger.warning(f"Error writing to FFmpeg: {e}. Restarting...") - ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer) - ffmpeg_process.stdin.write(message) - ffmpeg_process.stdin.flush() + audio_processor.process_audio(message) except WebSocketDisconnect: logger.warning("WebSocket disconnected.") finally: - for task in tasks: - task.cancel() - try: - await asyncio.gather(*tasks, return_exceptions=True) - ffmpeg_process.stdin.close() - ffmpeg_process.wait() - except Exception as e: - logger.warning(f"Error during cleanup: {e}") - if args.diarization and diarization: - diarization.close() + audio_processor.cleanup() logger.info("WebSocket endpoint cleaned up.") if __name__ == "__main__": From 7679370cf6c929af92e6a5e5577a392b2a2569aa Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 19 Mar 2025 10:59:50 +0100 Subject: [PATCH 3/4] Refactor AudioProcessor methods for improved async handling and WebSocket integration --- audio.py | 31 ++++++++++++++++--------------- whisper_fastapi_online_server.py | 18 ++++++++++++++++-- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/audio.py b/audio.py index ee6ca56..2436685 100644 --- a/audio.py +++ b/audio.py @@ -54,7 +54,7 @@ class AudioProcessor: / 32768.0) return pcm_array - async def start_ffmpeg_decoder(self): + def start_ffmpeg_decoder(self): """ Start an FFmpeg process in async streaming mode that reads WebM from stdin and outputs raw s16le PCM on stdout. Returns the process object. @@ -79,7 +79,7 @@ class AudioProcessor: await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) except Exception as e: logger.warning(f"Error killing FFmpeg process: {e}") - self.ffmpeg_process = await self.start_ffmpeg_decoder() + self.ffmpeg_process = self.start_ffmpeg_decoder() self.pcm_buffer = bytearray() async def ffmpeg_stdout_reader(self): @@ -198,10 +198,9 @@ class AudioProcessor: finally: self.diarization_queue.task_done() - async def results_formatter(self, websocket): + async def results_formatter(self): while True: try: - # Get the current state state = await self.shared_state.get_current_state() tokens = state["tokens"] buffer_transcription = state["buffer_transcription"] @@ -217,7 +216,6 @@ class AudioProcessor: sleep(0.5) state = await self.shared_state.get_current_state() tokens = state["tokens"] - # Process tokens to create response previous_speaker = -1 lines = [] last_end_diarized = 0 @@ -273,22 +271,21 @@ class AudioProcessor: "beg": format_time(0), "end": format_time(tokens[-1].end) if tokens else format_time(0), "diff": 0 - }], + }], "buffer_transcription": buffer_transcription, "buffer_diarization": buffer_diarization, "remaining_time_transcription": remaining_time_transcription, "remaining_time_diarization": remaining_time_diarization - } response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization if response_content != self.shared_state.last_response_content: if lines or buffer_transcription or buffer_diarization: - await websocket.send_json(response) + yield response self.shared_state.last_response_content = response_content - # Add a small delay to avoid overwhelming the client + #small delay to avoid overwhelming the client await asyncio.sleep(0.1) except Exception as e: @@ -296,18 +293,22 @@ class AudioProcessor: logger.warning(f"Traceback: {traceback.format_exc()}") await asyncio.sleep(0.5) # Back off on error - async def create_tasks(self, websocket, diarization): + async def create_tasks(self, diarization=None): + if diarization: + self.diarization = diarization + tasks = [] if self.args.transcription and self.online: tasks.append(asyncio.create_task(self.transcription_processor())) - if self.args.diarization and diarization: - tasks.append(asyncio.create_task(self.diarization_processor(diarization))) - formatter_task = asyncio.create_task(self.results_formatter(websocket)) - tasks.append(formatter_task) + if self.args.diarization and self.diarization: + tasks.append(asyncio.create_task(self.diarization_processor(self.diarization))) + stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) tasks.append(stdout_reader_task) + self.tasks = tasks - self.diarization = diarization + + return self.results_formatter() async def cleanup(self): for task in self.tasks: diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index b0ca658..cf81ff8 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -5,6 +5,7 @@ from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr +import asyncio import logging from parse_args import parse_args from audio import AudioProcessor @@ -51,6 +52,16 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f: async def get(): return HTMLResponse(html) + +async def handle_websocket_results(websocket, results_generator): + """Consumes results from the audio processor and sends them via WebSocket.""" + try: + async for response in results_generator: + await websocket.send_json(response) + except Exception as e: + logger.warning(f"Error in WebSocket results handler: {e}") + + @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket): audio_processor = AudioProcessor(args, asr, tokenizer) @@ -58,14 +69,17 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.accept() logger.info("WebSocket connection opened.") - await audio_processor.create_tasks(websocket, diarization) + results_generator = await audio_processor.create_tasks(diarization) + websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator)) + try: while True: message = await websocket.receive_bytes() - audio_processor.process_audio(message) + await audio_processor.process_audio(message) except WebSocketDisconnect: logger.warning("WebSocket disconnected.") finally: + websocket_task.cancel() audio_processor.cleanup() logger.info("WebSocket endpoint cleaned up.") From 5624c1f6b7185a1c92fc4fab0afc4a9272010183 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 19 Mar 2025 11:18:12 +0100 Subject: [PATCH 4/4] Refactor import statement for AudioProcessor and update cleanup method to be awaited; remove unused formatters and state management files --- audio.py | 335 ------------------------- audio_processor.py | 406 +++++++++++++++++++++++++++++++ formatters.py | 91 ------- state.py | 96 -------- whisper_fastapi_online_server.py | 4 +- 5 files changed, 408 insertions(+), 524 deletions(-) delete mode 100644 audio.py create mode 100644 audio_processor.py delete mode 100644 formatters.py delete mode 100644 state.py diff --git a/audio.py b/audio.py deleted file mode 100644 index 2436685..0000000 --- a/audio.py +++ /dev/null @@ -1,335 +0,0 @@ -import asyncio -import numpy as np -import ffmpeg -from time import time, sleep - - -from whisper_streaming_custom.whisper_online import online_factory -import math -import logging -import traceback -from state import SharedState -from formatters import format_time - - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logging.getLogger().setLevel(logging.WARNING) -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -class AudioProcessor: - - def __init__(self, args, asr, tokenizer): - self.args = args - self.sample_rate = 16000 - self.channels = 1 - self.samples_per_sec = int(self.sample_rate * args.min_chunk_size) - self.bytes_per_sample = 2 - self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample - self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz - - - self.shared_state = SharedState() - self.asr = asr - self.tokenizer = tokenizer - - self.ffmpeg_process = self.start_ffmpeg_decoder() - - self.transcription_queue = asyncio.Queue() if self.args.transcription else None - self.diarization_queue = asyncio.Queue() if self.args.diarization else None - - self.pcm_buffer = bytearray() - if self.args.transcription: - self.online = online_factory(self.args, self.asr, self.tokenizer) - - - - def convert_pcm_to_float(self, pcm_buffer): - """ - Converts a PCM buffer in s16le format to a normalized NumPy array. - Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format - Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0 - """ - pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) - / 32768.0) - return pcm_array - - def start_ffmpeg_decoder(self): - """ - Start an FFmpeg process in async streaming mode that reads WebM from stdin - and outputs raw s16le PCM on stdout. Returns the process object. - """ - process = ( - ffmpeg.input("pipe:0", format="webm") - .output( - "pipe:1", - format="s16le", - acodec="pcm_s16le", - ac=self.channels, - ar=str(self.sample_rate), - ) - .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True) - ) - return process - - async def restart_ffmpeg(self): - if self.ffmpeg_process: - try: - self.ffmpeg_process.kill() - await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) - except Exception as e: - logger.warning(f"Error killing FFmpeg process: {e}") - self.ffmpeg_process = self.start_ffmpeg_decoder() - self.pcm_buffer = bytearray() - - async def ffmpeg_stdout_reader(self): - loop = asyncio.get_event_loop() - beg = time() - - while True: - try: - elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec - ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096) - beg = time() - - # Read chunk with timeout - try: - chunk = await asyncio.wait_for( - loop.run_in_executor( - None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration - ), - timeout=15.0 - ) - except asyncio.TimeoutError: - logger.warning("FFmpeg read timeout. Restarting...") - await self.restart_ffmpeg() - beg = time() - continue # Skip processing and read from new process - - if not chunk: - logger.info("FFmpeg stdout closed.") - break - self.pcm_buffer.extend(chunk) - - if self.args.diarization and self.diarization_queue: - await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy()) - - if len(self.pcm_buffer) >= self.bytes_per_sec: - if len(self.pcm_buffer) > self.max_bytes_per_sec: - logger.warning( - f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds. - The model probably struggles to keep up. Consider using a smaller model. - """) - - pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec]) - self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:] - - if self.args.transcription and self.transcription_queue: - await self.transcription_queue.put(pcm_array.copy()) - - - if not self.args.transcription and not self.args.diarization: - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - break - logger.info("Exiting ffmpeg_stdout_reader...") - - async def transcription_processor(self): - full_transcription = "" - sep = self.online.asr.sep - - while True: - try: - pcm_array = await self.transcription_queue.get() - - logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.") - - # Process transcription - self.online.insert_audio_chunk(pcm_array) - new_tokens = self.online.process_iter() - - if new_tokens: - full_transcription += sep.join([t.text for t in new_tokens]) - - _buffer = self.online.get_buffer() - buffer = _buffer.text - end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0) - - if buffer in full_transcription: - buffer = "" - - await self.shared_state.update_transcription( - new_tokens, buffer, end_buffer, full_transcription, sep) - - except Exception as e: - logger.warning(f"Exception in transcription_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - self.transcription_queue.task_done() - - - async def diarization_processor(self, diarization_obj): - buffer_diarization = "" - - while True: - try: - pcm_array = await self.diarization_queue.get() - - # Process diarization - await diarization_obj.diarize(pcm_array) - - # Get current state - state = await self.shared_state.get_current_state() - tokens = state["tokens"] - end_attributed_speaker = state["end_attributed_speaker"] - - # Update speaker information - new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens( - end_attributed_speaker, tokens) - - await self.shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization) - - except Exception as e: - logger.warning(f"Exception in diarization_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - self.diarization_queue.task_done() - - async def results_formatter(self): - while True: - try: - state = await self.shared_state.get_current_state() - tokens = state["tokens"] - buffer_transcription = state["buffer_transcription"] - buffer_diarization = state["buffer_diarization"] - end_attributed_speaker = state["end_attributed_speaker"] - remaining_time_transcription = state["remaining_time_transcription"] - remaining_time_diarization = state["remaining_time_diarization"] - sep = state["sep"] - - # If diarization is enabled but no transcription, add dummy tokens periodically - if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization: - await self.shared_state.add_dummy_token() - sleep(0.5) - state = await self.shared_state.get_current_state() - tokens = state["tokens"] - previous_speaker = -1 - lines = [] - last_end_diarized = 0 - undiarized_text = [] - - for token in tokens: - speaker = token.speaker - if self.args.diarization: - if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: - undiarized_text.append(token.text) - continue - elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: - speaker = previous_speaker - if speaker not in [-1, 0]: - last_end_diarized = max(token.end, last_end_diarized) - - if speaker != previous_speaker or not lines: - lines.append( - { - "speaker": speaker, - "text": token.text, - "beg": format_time(token.start), - "end": format_time(token.end), - "diff": round(token.end - last_end_diarized, 2) - } - ) - previous_speaker = speaker - elif token.text: # Only append if text isn't empty - lines[-1]["text"] += sep + token.text - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) - - if undiarized_text: - combined_buffer_diarization = sep.join(undiarized_text) - if buffer_transcription: - combined_buffer_diarization += sep - await self.shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization) - buffer_diarization = combined_buffer_diarization - - if lines: - response = { - "lines": lines, - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - else: - response = { - "lines": [{ - "speaker": 1, - "text": "", - "beg": format_time(0), - "end": format_time(tokens[-1].end) if tokens else format_time(0), - "diff": 0 - }], - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization - - if response_content != self.shared_state.last_response_content: - if lines or buffer_transcription or buffer_diarization: - yield response - self.shared_state.last_response_content = response_content - - #small delay to avoid overwhelming the client - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in results_formatter: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - await asyncio.sleep(0.5) # Back off on error - - async def create_tasks(self, diarization=None): - if diarization: - self.diarization = diarization - - tasks = [] - if self.args.transcription and self.online: - tasks.append(asyncio.create_task(self.transcription_processor())) - if self.args.diarization and self.diarization: - tasks.append(asyncio.create_task(self.diarization_processor(self.diarization))) - - stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) - tasks.append(stdout_reader_task) - - self.tasks = tasks - - return self.results_formatter() - - async def cleanup(self): - for task in self.tasks: - task.cancel() - try: - await asyncio.gather(*self.tasks, return_exceptions=True) - self.ffmpeg_process.stdin.close() - self.ffmpeg_process.wait() - except Exception as e: - logger.warning(f"Error during cleanup: {e}") - if self.args.diarization and self.diarization: - self.diarization.close() - - async def process_audio(self, message): - try: - self.ffmpeg_process.stdin.write(message) - self.ffmpeg_process.stdin.flush() - except (BrokenPipeError, AttributeError) as e: - logger.warning(f"Error writing to FFmpeg: {e}. Restarting...") - await self.restart_ffmpeg() - self.ffmpeg_process.stdin.write(message) - self.ffmpeg_process.stdin.flush() - - \ No newline at end of file diff --git a/audio_processor.py b/audio_processor.py new file mode 100644 index 0000000..94bfbef --- /dev/null +++ b/audio_processor.py @@ -0,0 +1,406 @@ +import asyncio +import numpy as np +import ffmpeg +from time import time, sleep +import math +import logging +import traceback +from datetime import timedelta +from typing import List, Dict, Any +from timed_objects import ASRToken +from whisper_streaming_custom.whisper_online import online_factory + +# Set up logging once +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +def format_time(seconds: float) -> str: + """Format seconds as HH:MM:SS.""" + return str(timedelta(seconds=int(seconds))) + +class AudioProcessor: + """ + Processes audio streams for transcription and diarization. + Handles audio processing, state management, and result formatting in a single class. + """ + + def __init__(self, args, asr, tokenizer): + """Initialize the audio processor with configuration, models, and state.""" + # Audio processing settings + self.args = args + self.sample_rate = 16000 + self.channels = 1 + self.samples_per_sec = int(self.sample_rate * args.min_chunk_size) + self.bytes_per_sample = 2 + self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample + self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz + + # State management + self.tokens = [] + self.buffer_transcription = "" + self.buffer_diarization = "" + self.full_transcription = "" + self.end_buffer = 0 + self.end_attributed_speaker = 0 + self.lock = asyncio.Lock() + self.beg_loop = time() + self.sep = " " # Default separator + self.last_response_content = "" + + # Models and processing + self.asr = asr + self.tokenizer = tokenizer + self.ffmpeg_process = self.start_ffmpeg_decoder() + self.transcription_queue = asyncio.Queue() if args.transcription else None + self.diarization_queue = asyncio.Queue() if args.diarization else None + self.pcm_buffer = bytearray() + + # Initialize transcription engine if enabled + if args.transcription: + self.online = online_factory(args, asr, tokenizer) + + def convert_pcm_to_float(self, pcm_buffer): + """Convert PCM buffer in s16le format to normalized NumPy array.""" + return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 + + def start_ffmpeg_decoder(self): + """Start FFmpeg process for WebM to PCM conversion.""" + return (ffmpeg.input("pipe:0", format="webm") + .output("pipe:1", format="s16le", acodec="pcm_s16le", + ac=self.channels, ar=str(self.sample_rate)) + .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)) + + async def restart_ffmpeg(self): + """Restart the FFmpeg process after failure.""" + if self.ffmpeg_process: + try: + self.ffmpeg_process.kill() + await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) + except Exception as e: + logger.warning(f"Error killing FFmpeg process: {e}") + self.ffmpeg_process = self.start_ffmpeg_decoder() + self.pcm_buffer = bytearray() + + async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep): + """Thread-safe update of transcription with new data.""" + async with self.lock: + self.tokens.extend(new_tokens) + self.buffer_transcription = buffer + self.end_buffer = end_buffer + self.full_transcription = full_transcription + self.sep = sep + + async def update_diarization(self, end_attributed_speaker, buffer_diarization=""): + """Thread-safe update of diarization with new data.""" + async with self.lock: + self.end_attributed_speaker = end_attributed_speaker + if buffer_diarization: + self.buffer_diarization = buffer_diarization + + async def add_dummy_token(self): + """Placeholder token when no transcription is available.""" + async with self.lock: + current_time = time() - self.beg_loop + self.tokens.append(ASRToken( + start=current_time, end=current_time + 1, + text=".", speaker=-1, is_dummy=True + )) + + async def get_current_state(self): + """Get current state.""" + async with self.lock: + current_time = time() + + # Calculate remaining times + remaining_transcription = 0 + if self.end_buffer > 0: + remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) + + remaining_diarization = 0 + if self.tokens: + latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) + remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2)) + + return { + "tokens": self.tokens.copy(), + "buffer_transcription": self.buffer_transcription, + "buffer_diarization": self.buffer_diarization, + "end_buffer": self.end_buffer, + "end_attributed_speaker": self.end_attributed_speaker, + "sep": self.sep, + "remaining_time_transcription": remaining_transcription, + "remaining_time_diarization": remaining_diarization + } + + async def reset(self): + """Reset all state variables to initial values.""" + async with self.lock: + self.tokens = [] + self.buffer_transcription = self.buffer_diarization = "" + self.end_buffer = self.end_attributed_speaker = 0 + self.full_transcription = self.last_response_content = "" + self.beg_loop = time() + + async def ffmpeg_stdout_reader(self): + """Read audio data from FFmpeg stdout and process it.""" + loop = asyncio.get_event_loop() + beg = time() + + while True: + try: + # Calculate buffer size based on elapsed time + elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec + buffer_size = max(int(32000 * elapsed_time), 4096) + beg = time() + + # Read chunk with timeout + try: + chunk = await asyncio.wait_for( + loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size), + timeout=15.0 + ) + except asyncio.TimeoutError: + logger.warning("FFmpeg read timeout. Restarting...") + await self.restart_ffmpeg() + beg = time() + continue + + if not chunk: + logger.info("FFmpeg stdout closed.") + break + + self.pcm_buffer.extend(chunk) + + # Send to diarization if enabled + if self.args.diarization and self.diarization_queue: + await self.diarization_queue.put( + self.convert_pcm_to_float(self.pcm_buffer).copy() + ) + + # Process when we have enough data + if len(self.pcm_buffer) >= self.bytes_per_sec: + if len(self.pcm_buffer) > self.max_bytes_per_sec: + logger.warning( + f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. " + f"Consider using a smaller model." + ) + + # Process audio chunk + pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec]) + self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:] + + # Send to transcription if enabled + if self.args.transcription and self.transcription_queue: + await self.transcription_queue.put(pcm_array.copy()) + + # Sleep if no processing is happening + if not self.args.transcription and not self.args.diarization: + await asyncio.sleep(0.1) + + except Exception as e: + logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + break + + async def transcription_processor(self): + """Process audio chunks for transcription.""" + self.full_transcription = "" + self.sep = self.online.asr.sep + + while True: + try: + pcm_array = await self.transcription_queue.get() + + logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.") + + # Process transcription + self.online.insert_audio_chunk(pcm_array) + new_tokens = self.online.process_iter() + + if new_tokens: + self.full_transcription += self.sep.join([t.text for t in new_tokens]) + + # Get buffer information + _buffer = self.online.get_buffer() + buffer = _buffer.text + end_buffer = _buffer.end if _buffer.end else ( + new_tokens[-1].end if new_tokens else 0 + ) + + # Avoid duplicating content + if buffer in self.full_transcription: + buffer = "" + + await self.update_transcription( + new_tokens, buffer, end_buffer, self.full_transcription, self.sep + ) + + except Exception as e: + logger.warning(f"Exception in transcription_processor: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + finally: + self.transcription_queue.task_done() + + async def diarization_processor(self, diarization_obj): + """Process audio chunks for speaker diarization.""" + buffer_diarization = "" + + while True: + try: + pcm_array = await self.diarization_queue.get() + + # Process diarization + await diarization_obj.diarize(pcm_array) + + # Get current state and update speakers + state = await self.get_current_state() + new_end = diarization_obj.assign_speakers_to_tokens( + state["end_attributed_speaker"], state["tokens"] + ) + + await self.update_diarization(new_end, buffer_diarization) + + except Exception as e: + logger.warning(f"Exception in diarization_processor: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + finally: + self.diarization_queue.task_done() + + async def results_formatter(self): + """Format processing results for output.""" + while True: + try: + # Get current state + state = await self.get_current_state() + tokens = state["tokens"] + buffer_transcription = state["buffer_transcription"] + buffer_diarization = state["buffer_diarization"] + end_attributed_speaker = state["end_attributed_speaker"] + sep = state["sep"] + + # Add dummy tokens if needed + if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization: + await self.add_dummy_token() + sleep(0.5) + state = await self.get_current_state() + tokens = state["tokens"] + + # Format output + previous_speaker = -1 + lines = [] + last_end_diarized = 0 + undiarized_text = [] + + # Process each token + for token in tokens: + speaker = token.speaker + + # Handle diarization + if self.args.diarization: + if (speaker in [-1, 0]) and token.end >= end_attributed_speaker: + undiarized_text.append(token.text) + continue + elif (speaker in [-1, 0]) and token.end < end_attributed_speaker: + speaker = previous_speaker + if speaker not in [-1, 0]: + last_end_diarized = max(token.end, last_end_diarized) + + # Group by speaker + if speaker != previous_speaker or not lines: + lines.append({ + "speaker": speaker, + "text": token.text, + "beg": format_time(token.start), + "end": format_time(token.end), + "diff": round(token.end - last_end_diarized, 2) + }) + previous_speaker = speaker + elif token.text: # Only append if text isn't empty + lines[-1]["text"] += sep + token.text + lines[-1]["end"] = format_time(token.end) + lines[-1]["diff"] = round(token.end - last_end_diarized, 2) + + # Handle undiarized text + if undiarized_text: + combined = sep.join(undiarized_text) + if buffer_transcription: + combined += sep + await self.update_diarization(end_attributed_speaker, combined) + buffer_diarization = combined + + # Create response object + if not lines: + lines = [{ + "speaker": 1, + "text": "", + "beg": format_time(0), + "end": format_time(tokens[-1].end if tokens else 0), + "diff": 0 + }] + + response = { + "lines": lines, + "buffer_transcription": buffer_transcription, + "buffer_diarization": buffer_diarization, + "remaining_time_transcription": state["remaining_time_transcription"], + "remaining_time_diarization": state["remaining_time_diarization"] + } + + # Only yield if content has changed + response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \ + f" | {buffer_transcription} | {buffer_diarization}" + + if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization): + yield response + self.last_response_content = response_content + + await asyncio.sleep(0.1) # Avoid overwhelming the client + + except Exception as e: + logger.warning(f"Exception in results_formatter: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + await asyncio.sleep(0.5) # Back off on error + + async def create_tasks(self, diarization=None): + """Create and start processing tasks.""" + if diarization: + self.diarization = diarization + + tasks = [] + if self.args.transcription and self.online: + tasks.append(asyncio.create_task(self.transcription_processor())) + + if self.args.diarization and self.diarization: + tasks.append(asyncio.create_task(self.diarization_processor(self.diarization))) + + tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader())) + self.tasks = tasks + + return self.results_formatter() + + async def cleanup(self): + """Clean up resources when processing is complete.""" + for task in self.tasks: + task.cancel() + + try: + await asyncio.gather(*self.tasks, return_exceptions=True) + self.ffmpeg_process.stdin.close() + self.ffmpeg_process.wait() + except Exception as e: + logger.warning(f"Error during cleanup: {e}") + + if self.args.diarization and hasattr(self, 'diarization'): + self.diarization.close() + + async def process_audio(self, message): + """Process incoming audio data.""" + try: + self.ffmpeg_process.stdin.write(message) + self.ffmpeg_process.stdin.flush() + except (BrokenPipeError, AttributeError) as e: + logger.warning(f"Error writing to FFmpeg: {e}. Restarting...") + await self.restart_ffmpeg() + self.ffmpeg_process.stdin.write(message) + self.ffmpeg_process.stdin.flush() \ No newline at end of file diff --git a/formatters.py b/formatters.py deleted file mode 100644 index d48473f..0000000 --- a/formatters.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import Dict, Any, List -from datetime import timedelta - -def format_time(seconds: float) -> str: - """Format seconds as HH:MM:SS.""" - return str(timedelta(seconds=int(seconds))) - -def format_response(state: Dict[str, Any], with_diarization: bool = False) -> Dict[str, Any]: - """ - Format the shared state into a client-friendly response. - - Args: - state: Current shared state dictionary - with_diarization: Whether to include diarization formatting - - Returns: - Formatted response dictionary ready to send to client - """ - tokens = state["tokens"] - buffer_transcription = state["buffer_transcription"] - buffer_diarization = state["buffer_diarization"] - end_attributed_speaker = state["end_attributed_speaker"] - remaining_time_transcription = state["remaining_time_transcription"] - remaining_time_diarization = state["remaining_time_diarization"] - sep = state["sep"] - - # Default response for empty state - if not tokens: - return { - "lines": [{ - "speaker": 1, - "text": "", - "beg": format_time(0), - "end": format_time(0), - "diff": 0 - }], - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - # Process tokens to create response - previous_speaker = -1 - lines = [] - last_end_diarized = 0 - undiarized_text = [] - - for token in tokens: - speaker = token.speaker - - # Handle diarization logic - if with_diarization: - if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: - undiarized_text.append(token.text) - continue - elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: - speaker = previous_speaker - - if speaker not in [-1, 0]: - last_end_diarized = max(token.end, last_end_diarized) - - # Add new line or append to existing line - if speaker != previous_speaker or not lines: - lines.append({ - "speaker": speaker, - "text": token.text, - "beg": format_time(token.start), - "end": format_time(token.end), - "diff": round(token.end - last_end_diarized, 2) - }) - previous_speaker = speaker - elif token.text: # Only append if text isn't empty - lines[-1]["text"] += sep + token.text - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) - - # If we have undiarized text, include it in the buffer - if undiarized_text: - combined_buffer = sep.join(undiarized_text) - if buffer_transcription: - combined_buffer += sep + buffer_transcription - buffer_diarization = combined_buffer - - return { - "lines": lines, - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } \ No newline at end of file diff --git a/state.py b/state.py deleted file mode 100644 index a4c864a..0000000 --- a/state.py +++ /dev/null @@ -1,96 +0,0 @@ -import asyncio -import logging -from time import time -from typing import List, Dict, Any, Optional -from dataclasses import dataclass, field -from timed_objects import ASRToken - -logger = logging.getLogger(__name__) - - -class SharedState: - """ - Thread-safe state manager for streaming transcription and diarization. - Handles coordination between audio processing, transcription, and diarization. - """ - - def __init__(self): - self.tokens: List[ASRToken] = [] - self.buffer_transcription: str = "" - self.buffer_diarization: str = "" - self.full_transcription: str = "" - self.end_buffer: float = 0 - self.end_attributed_speaker: float = 0 - self.lock = asyncio.Lock() - self.beg_loop: float = time() - self.sep: str = " " # Default separator - self.last_response_content: str = "" # To track changes in response - - async def update_transcription(self, new_tokens: List[ASRToken], buffer: str, - end_buffer: float, full_transcription: str, sep: str) -> None: - """Update the state with new transcription data.""" - async with self.lock: - self.tokens.extend(new_tokens) - self.buffer_transcription = buffer - self.end_buffer = end_buffer - self.full_transcription = full_transcription - self.sep = sep - - async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None: - """Update the state with new diarization data.""" - async with self.lock: - self.end_attributed_speaker = end_attributed_speaker - if buffer_diarization: - self.buffer_diarization = buffer_diarization - - async def add_dummy_token(self) -> None: - """Add a dummy token to keep the state updated even without transcription.""" - async with self.lock: - current_time = time() - self.beg_loop - dummy_token = ASRToken( - start=current_time, - end=current_time + 1, - text=".", - speaker=-1, - is_dummy=True - ) - self.tokens.append(dummy_token) - - async def get_current_state(self) -> Dict[str, Any]: - """Get the current state with calculated timing information.""" - async with self.lock: - current_time = time() - remaining_time_transcription = 0 - remaining_time_diarization = 0 - - # Calculate remaining time for transcription buffer - if self.end_buffer > 0: - remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) - - # Calculate remaining time for diarization - if self.tokens: - latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) - remaining_time_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2)) - - return { - "tokens": self.tokens.copy(), - "buffer_transcription": self.buffer_transcription, - "buffer_diarization": self.buffer_diarization, - "end_buffer": self.end_buffer, - "end_attributed_speaker": self.end_attributed_speaker, - "sep": self.sep, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - async def reset(self) -> None: - """Reset the state to initial values.""" - async with self.lock: - self.tokens = [] - self.buffer_transcription = "" - self.buffer_diarization = "" - self.end_buffer = 0 - self.end_attributed_speaker = 0 - self.full_transcription = "" - self.beg_loop = time() - self.last_response_content = "" \ No newline at end of file diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index cf81ff8..6c41f27 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -8,7 +8,7 @@ from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr import asyncio import logging from parse_args import parse_args -from audio import AudioProcessor +from audio_processor import AudioProcessor logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.getLogger().setLevel(logging.WARNING) @@ -80,7 +80,7 @@ async def websocket_endpoint(websocket: WebSocket): logger.warning("WebSocket disconnected.") finally: websocket_task.cancel() - audio_processor.cleanup() + await audio_processor.cleanup() logger.info("WebSocket endpoint cleaned up.") if __name__ == "__main__":