import asyncio
import math
import logging
import traceback
from datetime import timedelta
from time import time

import numpy as np
import ffmpeg

from timed_objects import ASRToken
from whisper_streaming_custom.whisper_online import online_factory

# Set up logging once for this module.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def format_time(seconds: float) -> str:
    """Format seconds as HH:MM:SS (fractional part truncated)."""
    return str(timedelta(seconds=int(seconds)))


class AudioProcessor:
    """
    Processes an audio stream for transcription and diarization.

    Owns the FFmpeg decoder (WebM -> raw s16le PCM), the per-connection
    shared state, and the async worker loops that feed the ASR and
    diarization engines. One instance per client connection.
    """

    def __init__(self, args, asr, tokenizer):
        """Initialize the audio processor with configuration, models, and state."""
        # Audio processing settings.
        self.args = args
        self.sample_rate = 16000
        self.channels = 1
        self.samples_per_sec = int(self.sample_rate * args.min_chunk_size)
        self.bytes_per_sample = 2  # s16le = 2 bytes per sample
        self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
        # 16 kHz mono s16le is 32 000 bytes/second; cap the backlog at 5 s.
        self.max_bytes_per_sec = 32000 * 5

        # State management (every field below is guarded by self.lock).
        self.tokens = []
        self.buffer_transcription = ""
        self.buffer_diarization = ""
        self.full_transcription = ""
        self.end_buffer = 0
        self.end_attributed_speaker = 0
        self.lock = asyncio.Lock()
        self.beg_loop = time()
        self.sep = " "  # Default separator between tokens
        self.last_response_content = ""

        # Models and processing.
        self.asr = asr
        self.tokenizer = tokenizer
        # Initialized up front so create_tasks()/cleanup() never hit an
        # AttributeError when diarization is disabled or tasks never started.
        self.diarization = None
        self.tasks = []
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.transcription_queue = asyncio.Queue() if args.transcription else None
        self.diarization_queue = asyncio.Queue() if args.diarization else None
        self.pcm_buffer = bytearray()

        # Initialize the streaming transcription engine if enabled.
        self.online = online_factory(args, asr, tokenizer) if args.transcription else None

    def convert_pcm_to_float(self, pcm_buffer):
        """Convert a PCM buffer in s16le format to a float32 array in [-1.0, 1.0)."""
        return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0

    def start_ffmpeg_decoder(self):
        """Start an FFmpeg process converting WebM on stdin to raw PCM on stdout."""
        return (ffmpeg.input("pipe:0", format="webm")
                .output("pipe:1", format="s16le", acodec="pcm_s16le",
                        ac=self.channels, ar=str(self.sample_rate))
                .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))

    async def restart_ffmpeg(self):
        """Restart the FFmpeg process after a failure and reset decode-side state."""
        if self.ffmpeg_process:
            try:
                self.ffmpeg_process.kill()
                await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
            except Exception as e:
                logger.warning(f"Error killing FFmpeg process: {e}")
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.pcm_buffer = bytearray()
        # Audio continuity is broken across a restart, so rebuild the streaming
        # ASR engine as the pre-refactor server did; its internal audio buffer
        # would otherwise straddle the gap.
        if self.args.transcription:
            self.online = online_factory(self.args, self.asr, self.tokenizer)

    async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
        """Thread-safe update of transcription state with new data."""
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.buffer_transcription = buffer
            self.end_buffer = end_buffer
            self.full_transcription = full_transcription
            self.sep = sep

    async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
        """Thread-safe update of diarization state with new data."""
        async with self.lock:
            self.end_attributed_speaker = end_attributed_speaker
            if buffer_diarization:
                self.buffer_diarization = buffer_diarization

    async def add_dummy_token(self):
        """Append a placeholder token when no transcription is available."""
        async with self.lock:
            current_time = time() - self.beg_loop
            self.tokens.append(ASRToken(
                start=current_time, end=current_time + 1,
                text=".", speaker=-1, is_dummy=True
            ))

    async def get_current_state(self):
        """Return a consistent snapshot of the processing state.

        Returns a dict with a *copy* of the token list, the current buffers,
        and the estimated lag (in seconds) of transcription and diarization.
        """
        async with self.lock:
            current_time = time()

            # Transcription lag: wall-clock elapsed minus last committed end.
            remaining_transcription = 0
            if self.end_buffer > 0:
                remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))

            # Diarization lag: newest token end minus last speaker-attributed end.
            remaining_diarization = 0
            if self.tokens:
                latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
                remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))

            return {
                "tokens": self.tokens.copy(),
                "buffer_transcription": self.buffer_transcription,
                "buffer_diarization": self.buffer_diarization,
                "end_buffer": self.end_buffer,
                "end_attributed_speaker": self.end_attributed_speaker,
                "sep": self.sep,
                "remaining_time_transcription": remaining_transcription,
                "remaining_time_diarization": remaining_diarization
            }

    async def reset(self):
        """Reset all state variables to their initial values."""
        async with self.lock:
            self.tokens = []
            self.buffer_transcription = self.buffer_diarization = ""
            self.end_buffer = self.end_attributed_speaker = 0
            self.full_transcription = self.last_response_content = ""
            self.beg_loop = time()

    async def ffmpeg_stdout_reader(self):
        """Read PCM from FFmpeg stdout and fan it out to the processing queues."""
        loop = asyncio.get_event_loop()
        beg = time()

        while True:
            try:
                # Size the read roughly to the audio received since last pass.
                elapsed_time = math.floor((time() - beg) * 10) / 10  # Round down to 0.1 s
                buffer_size = max(int(32000 * elapsed_time), 4096)
                beg = time()

                # Read chunk with a timeout so a wedged FFmpeg gets restarted.
                try:
                    chunk = await asyncio.wait_for(
                        loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
                        timeout=15.0
                    )
                except asyncio.TimeoutError:
                    logger.warning("FFmpeg read timeout. Restarting...")
                    await self.restart_ffmpeg()
                    beg = time()
                    continue

                if not chunk:
                    logger.info("FFmpeg stdout closed.")
                    break

                self.pcm_buffer.extend(chunk)

                # Send to diarization if enabled.
                # NOTE(review): this forwards the *accumulated* buffer, so
                # successive puts contain overlapping audio; behavior inherited
                # from the pre-refactor server — confirm diarization expects it.
                if self.args.diarization and self.diarization_queue:
                    await self.diarization_queue.put(
                        self.convert_pcm_to_float(self.pcm_buffer).copy()
                    )

                # Process once we have at least one chunk's worth of audio.
                if len(self.pcm_buffer) >= self.bytes_per_sec:
                    if len(self.pcm_buffer) > self.max_bytes_per_sec:
                        logger.warning(
                            f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
                            f"Consider using a smaller model."
                        )

                    # Hand off at most max_bytes_per_sec worth of audio.
                    pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
                    self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]

                    # Send to transcription if enabled.
                    if self.args.transcription and self.transcription_queue:
                        await self.transcription_queue.put(pcm_array.copy())

                # Avoid a hot loop when nothing consumes the audio.
                if not self.args.transcription and not self.args.diarization:
                    await asyncio.sleep(0.1)

            except Exception as e:
                logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                break

    async def transcription_processor(self):
        """Consume PCM chunks from the transcription queue and run the ASR engine."""
        self.full_transcription = ""
        self.sep = self.online.asr.sep

        while True:
            # get() sits outside the try so that cancellation while waiting
            # does not trigger the finally's task_done() without a matching
            # get() (the original raised "task_done() called too many times").
            pcm_array = await self.transcription_queue.get()
            try:
                logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")

                # Process transcription.
                self.online.insert_audio_chunk(pcm_array)
                new_tokens = self.online.process_iter()

                if new_tokens:
                    self.full_transcription += self.sep.join([t.text for t in new_tokens])

                # Get buffer information.
                _buffer = self.online.get_buffer()
                buffer = _buffer.text
                end_buffer = _buffer.end if _buffer.end else (
                    new_tokens[-1].end if new_tokens else 0
                )

                # Avoid duplicating content already committed to the transcript.
                if buffer in self.full_transcription:
                    buffer = ""

                await self.update_transcription(
                    new_tokens, buffer, end_buffer, self.full_transcription, self.sep
                )

            except Exception as e:
                logger.warning(f"Exception in transcription_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.transcription_queue.task_done()

    async def diarization_processor(self, diarization_obj):
        """Consume PCM chunks from the diarization queue and attribute speakers."""
        buffer_diarization = ""

        while True:
            # See transcription_processor: get() outside try keeps task_done()
            # matched 1:1 with successful gets.
            pcm_array = await self.diarization_queue.get()
            try:
                # Process diarization.
                await diarization_obj.diarize(pcm_array)

                # Get current state and update speaker attribution.
                state = await self.get_current_state()
                new_end = diarization_obj.assign_speakers_to_tokens(
                    state["end_attributed_speaker"], state["tokens"]
                )

                await self.update_diarization(new_end, buffer_diarization)

            except Exception as e:
                logger.warning(f"Exception in diarization_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.diarization_queue.task_done()

    async def results_formatter(self):
        """Async generator yielding formatted result payloads when they change."""
        while True:
            try:
                # Get a consistent snapshot of the current state.
                state = await self.get_current_state()
                tokens = state["tokens"]
                buffer_transcription = state["buffer_transcription"]
                buffer_diarization = state["buffer_diarization"]
                end_attributed_speaker = state["end_attributed_speaker"]
                sep = state["sep"]

                # Diarization-only mode: keep feeding dummy tokens so speaker
                # attribution has something to attach to.
                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                    await self.add_dummy_token()
                    # Non-blocking pause; the original used time.sleep(0.5),
                    # which stalled the entire event loop.
                    await asyncio.sleep(0.5)
                    state = await self.get_current_state()
                    tokens = state["tokens"]

                # Format output.
                previous_speaker = -1
                lines = []
                last_end_diarized = 0
                undiarized_text = []

                # Process each token.
                for token in tokens:
                    speaker = token.speaker

                    # Handle diarization: tokens past the attributed horizon are
                    # held back; earlier unattributed tokens inherit the previous
                    # speaker.
                    if self.args.diarization:
                        if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
                            undiarized_text.append(token.text)
                            continue
                        elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
                            speaker = previous_speaker
                        if speaker not in [-1, 0]:
                            last_end_diarized = max(token.end, last_end_diarized)

                    # Group consecutive tokens by speaker.
                    if speaker != previous_speaker or not lines:
                        lines.append({
                            "speaker": speaker,
                            "text": token.text,
                            "beg": format_time(token.start),
                            "end": format_time(token.end),
                            "diff": round(token.end - last_end_diarized, 2)
                        })
                        previous_speaker = speaker
                    elif token.text:  # Only append if text isn't empty
                        lines[-1]["text"] += sep + token.text
                        lines[-1]["end"] = format_time(token.end)
                        lines[-1]["diff"] = round(token.end - last_end_diarized, 2)

                # Text not yet attributed to a speaker is surfaced as a buffer.
                if undiarized_text:
                    combined = sep.join(undiarized_text)
                    if buffer_transcription:
                        combined += sep
                    await self.update_diarization(end_attributed_speaker, combined)
                    buffer_diarization = combined

                # Always emit at least one (possibly empty) line.
                if not lines:
                    lines = [{
                        "speaker": 1,
                        "text": "",
                        "beg": format_time(0),
                        "end": format_time(tokens[-1].end if tokens else 0),
                        "diff": 0
                    }]

                response = {
                    "lines": lines,
                    "buffer_transcription": buffer_transcription,
                    "buffer_diarization": buffer_diarization,
                    "remaining_time_transcription": state["remaining_time_transcription"],
                    "remaining_time_diarization": state["remaining_time_diarization"]
                }

                # Only yield if content has changed since the last emission.
                response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
                                   f" | {buffer_transcription} | {buffer_diarization}"

                if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
                    yield response
                    self.last_response_content = response_content

                await asyncio.sleep(0.1)  # Avoid overwhelming the client

            except Exception as e:
                logger.warning(f"Exception in results_formatter: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                await asyncio.sleep(0.5)  # Back off on error

    async def create_tasks(self, diarization=None):
        """Start the background worker tasks and return the results generator."""
        if diarization:
            self.diarization = diarization

        tasks = []
        if self.args.transcription and self.online:
            tasks.append(asyncio.create_task(self.transcription_processor()))

        if self.args.diarization and self.diarization:
            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))

        tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
        self.tasks = tasks

        return self.results_formatter()

    async def cleanup(self):
        """Cancel workers and release the FFmpeg process and diarization engine."""
        for task in self.tasks:
            task.cancel()

        try:
            await asyncio.gather(*self.tasks, return_exceptions=True)
            if self.ffmpeg_process.stdin:
                self.ffmpeg_process.stdin.close()
            self.ffmpeg_process.wait()
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")

        if self.args.diarization and self.diarization:
            self.diarization.close()

    async def process_audio(self, message):
        """Feed an incoming WebM chunk to FFmpeg, restarting it once on failure."""
        try:
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
        except (BrokenPipeError, AttributeError) as e:
            logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
            await self.restart_ffmpeg()
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
+ """, + ) + + parser.add_argument( + "--confidence-validation", + type=bool, + default=False, + help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.", + ) + + parser.add_argument( + "--diarization", + type=bool, + default=True, + help="Whether to enable speaker diarization.", + ) + + parser.add_argument( + "--transcription", + type=bool, + default=True, + help="To disable to only see live diarization results.", + ) + + add_shared_args(parser) + args = parser.parse_args() + return args \ No newline at end of file diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index 24597bc..6c41f27 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -1,164 +1,22 @@ -import io -import argparse -import asyncio -import numpy as np -import ffmpeg -from time import time, sleep from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware -from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr -from timed_objects import ASRToken - -import math +from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr +import asyncio import logging -from datetime import timedelta -import traceback - -def format_time(seconds): - return str(timedelta(seconds=int(seconds))) - +from parse_args import parse_args +from audio_processor import AudioProcessor logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.getLogger().setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -##### LOAD ARGS ##### +args = parse_args() -parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server") -parser.add_argument( - "--host", - type=str, - default="localhost", - 
help="The host address to bind the server to.", -) -parser.add_argument( - "--port", type=int, default=8000, help="The port number to bind the server to." -) -parser.add_argument( - "--warmup-file", - type=str, - default=None, - dest="warmup_file", - help=""" - The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. - If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav. - If False, no warmup is performed. - """, -) - -parser.add_argument( - "--confidence-validation", - type=bool, - default=False, - help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.", -) - -parser.add_argument( - "--diarization", - type=bool, - default=False, - help="Whether to enable speaker diarization.", -) - -parser.add_argument( - "--transcription", - type=bool, - default=True, - help="To disable to only see live diarization results.", -) - -add_shared_args(parser) -args = parser.parse_args() - -SAMPLE_RATE = 16000 -CHANNELS = 1 -SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size) -BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample -BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE -MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz - - -class SharedState: - def __init__(self): - self.tokens = [] - self.buffer_transcription = "" - self.buffer_diarization = "" - self.full_transcription = "" - self.end_buffer = 0 - self.end_attributed_speaker = 0 - self.lock = asyncio.Lock() - self.beg_loop = time() - self.sep = " " # Default separator - self.last_response_content = "" # To track changes in response - - async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep): - async with self.lock: - self.tokens.extend(new_tokens) - self.buffer_transcription = buffer - self.end_buffer = end_buffer - self.full_transcription = full_transcription - self.sep = sep - - async def update_diarization(self, 
end_attributed_speaker, buffer_diarization=""): - async with self.lock: - self.end_attributed_speaker = end_attributed_speaker - if buffer_diarization: - self.buffer_diarization = buffer_diarization - - async def add_dummy_token(self): - async with self.lock: - current_time = time() - self.beg_loop - dummy_token = ASRToken( - start=current_time, - end=current_time + 1, - text=".", - speaker=-1, - is_dummy=True - ) - self.tokens.append(dummy_token) - - async def get_current_state(self): - async with self.lock: - current_time = time() - remaining_time_transcription = 0 - remaining_time_diarization = 0 - - # Calculate remaining time for transcription buffer - if self.end_buffer > 0: - remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) - - # Calculate remaining time for diarization - remaining_time_diarization = max(0, round(max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) - self.end_attributed_speaker, 2)) - - return { - "tokens": self.tokens.copy(), - "buffer_transcription": self.buffer_transcription, - "buffer_diarization": self.buffer_diarization, - "end_buffer": self.end_buffer, - "end_attributed_speaker": self.end_attributed_speaker, - "sep": self.sep, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - async def reset(self): - """Reset the state.""" - async with self.lock: - self.tokens = [] - self.buffer_transcription = "" - self.buffer_diarization = "" - self.end_buffer = 0 - self.end_attributed_speaker = 0 - self.full_transcription = "" - self.beg_loop = time() - self.last_response_content = "" - -##### LOAD APP ##### @asynccontextmanager async def lifespan(app: FastAPI): @@ -171,7 +29,7 @@ async def lifespan(app: FastAPI): if args.diarization: from diarization.diarization_online import DiartDiarization - diarization = DiartDiarization(SAMPLE_RATE) + diarization = DiartDiarization() else : diarization = None yield @@ 
-190,328 +48,39 @@ app.add_middleware( with open("web/live_transcription.html", "r", encoding="utf-8") as f: html = f.read() -def convert_pcm_to_float(pcm_buffer): - """ - Converts a PCM buffer in s16le format to a normalized NumPy array. - Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format - Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0 - """ - pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) - / 32768.0) - return pcm_array - -async def start_ffmpeg_decoder(): - """ - Start an FFmpeg process in async streaming mode that reads WebM from stdin - and outputs raw s16le PCM on stdout. Returns the process object. - """ - process = ( - ffmpeg.input("pipe:0", format="webm") - .output( - "pipe:1", - format="s16le", - acodec="pcm_s16le", - ac=CHANNELS, - ar=str(SAMPLE_RATE), - ) - .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True) - ) - return process - -async def transcription_processor(shared_state, pcm_queue, online): - full_transcription = "" - sep = online.asr.sep - - while True: - try: - pcm_array = await pcm_queue.get() - - logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.") - - # Process transcription - online.insert_audio_chunk(pcm_array) - new_tokens = online.process_iter() - - if new_tokens: - full_transcription += sep.join([t.text for t in new_tokens]) - - _buffer = online.get_buffer() - buffer = _buffer.text - end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0) - - if buffer in full_transcription: - buffer = "" - - await shared_state.update_transcription( - new_tokens, buffer, end_buffer, full_transcription, sep) - - except Exception as e: - logger.warning(f"Exception in transcription_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - pcm_queue.task_done() - -async def diarization_processor(shared_state, pcm_queue, diarization_obj): - 
buffer_diarization = "" - - while True: - try: - pcm_array = await pcm_queue.get() - - # Process diarization - await diarization_obj.diarize(pcm_array) - - # Get current state - state = await shared_state.get_current_state() - tokens = state["tokens"] - end_attributed_speaker = state["end_attributed_speaker"] - - # Update speaker information - new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens( - end_attributed_speaker, tokens) - - await shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization) - - except Exception as e: - logger.warning(f"Exception in diarization_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - pcm_queue.task_done() - -async def results_formatter(shared_state, websocket): - while True: - try: - # Get the current state - state = await shared_state.get_current_state() - tokens = state["tokens"] - buffer_transcription = state["buffer_transcription"] - buffer_diarization = state["buffer_diarization"] - end_attributed_speaker = state["end_attributed_speaker"] - remaining_time_transcription = state["remaining_time_transcription"] - remaining_time_diarization = state["remaining_time_diarization"] - sep = state["sep"] - - # If diarization is enabled but no transcription, add dummy tokens periodically - if (not tokens or tokens[-1].is_dummy) and not args.transcription and args.diarization: - await shared_state.add_dummy_token() - sleep(0.5) - state = await shared_state.get_current_state() - tokens = state["tokens"] - # Process tokens to create response - previous_speaker = -1 - lines = [] - last_end_diarized = 0 - undiarized_text = [] - - for token in tokens: - speaker = token.speaker - if args.diarization: - if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: - undiarized_text.append(token.text) - continue - elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: - speaker = previous_speaker - if speaker not in [-1, 0]: - 
last_end_diarized = max(token.end, last_end_diarized) - - if speaker != previous_speaker or not lines: - lines.append( - { - "speaker": speaker, - "text": token.text, - "beg": format_time(token.start), - "end": format_time(token.end), - "diff": round(token.end - last_end_diarized, 2) - } - ) - previous_speaker = speaker - elif token.text: # Only append if text isn't empty - lines[-1]["text"] += sep + token.text - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) - - if undiarized_text: - combined_buffer_diarization = sep.join(undiarized_text) - if buffer_transcription: - combined_buffer_diarization += sep - await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization) - buffer_diarization = combined_buffer_diarization - - if lines: - response = { - "lines": lines, - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - else: - response = { - "lines": [{ - "speaker": 1, - "text": "", - "beg": format_time(0), - "end": format_time(tokens[-1].end) if tokens else format_time(0), - "diff": 0 - }], - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - - } - - response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization - - if response_content != shared_state.last_response_content: - if lines or buffer_transcription or buffer_diarization: - await websocket.send_json(response) - shared_state.last_response_content = response_content - - # Add a small delay to avoid overwhelming the client - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in 
results_formatter: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - await asyncio.sleep(0.5) # Back off on error - -##### ENDPOINTS ##### - @app.get("/") async def get(): return HTMLResponse(html) + +async def handle_websocket_results(websocket, results_generator): + """Consumes results from the audio processor and sends them via WebSocket.""" + try: + async for response in results_generator: + await websocket.send_json(response) + except Exception as e: + logger.warning(f"Error in WebSocket results handler: {e}") + + @app.websocket("/asr") async def websocket_endpoint(websocket: WebSocket): + audio_processor = AudioProcessor(args, asr, tokenizer) + await websocket.accept() logger.info("WebSocket connection opened.") + + results_generator = await audio_processor.create_tasks(diarization) + websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator)) - ffmpeg_process = None - pcm_buffer = bytearray() - shared_state = SharedState() - - transcription_queue = asyncio.Queue() if args.transcription else None - diarization_queue = asyncio.Queue() if args.diarization else None - - online = None - - async def restart_ffmpeg(): - nonlocal ffmpeg_process, online, pcm_buffer - if ffmpeg_process: - try: - ffmpeg_process.kill() - await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait) - except Exception as e: - logger.warning(f"Error killing FFmpeg process: {e}") - ffmpeg_process = await start_ffmpeg_decoder() - pcm_buffer = bytearray() - - if args.transcription: - online = online_factory(args, asr, tokenizer) - - await shared_state.reset() - logger.info("FFmpeg process started.") - - await restart_ffmpeg() - - tasks = [] - if args.transcription and online: - tasks.append(asyncio.create_task( - transcription_processor(shared_state, transcription_queue, online))) - if args.diarization and diarization: - tasks.append(asyncio.create_task( - diarization_processor(shared_state, diarization_queue, diarization))) - 
formatter_task = asyncio.create_task(results_formatter(shared_state, websocket)) - tasks.append(formatter_task) - - async def ffmpeg_stdout_reader(): - nonlocal ffmpeg_process, pcm_buffer - loop = asyncio.get_event_loop() - beg = time() - - while True: - try: - elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec - ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096) - beg = time() - - # Read chunk with timeout - try: - chunk = await asyncio.wait_for( - loop.run_in_executor( - None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration - ), - timeout=15.0 - ) - except asyncio.TimeoutError: - logger.warning("FFmpeg read timeout. Restarting...") - await restart_ffmpeg() - beg = time() - continue # Skip processing and read from new process - - if not chunk: - logger.info("FFmpeg stdout closed.") - break - pcm_buffer.extend(chunk) - - if args.diarization and diarization_queue: - await diarization_queue.put(convert_pcm_to_float(pcm_buffer).copy()) - - if len(pcm_buffer) >= BYTES_PER_SEC: - if len(pcm_buffer) > MAX_BYTES_PER_SEC: - logger.warning( - f"""Audio buffer is too large: {len(pcm_buffer) / BYTES_PER_SEC:.2f} seconds. - The model probably struggles to keep up. Consider using a smaller model. 
- """) - - pcm_array = convert_pcm_to_float(pcm_buffer[:MAX_BYTES_PER_SEC]) - pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:] - - if args.transcription and transcription_queue: - await transcription_queue.put(pcm_array.copy()) - - - if not args.transcription and not args.diarization: - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - break - - logger.info("Exiting ffmpeg_stdout_reader...") - - stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader()) - tasks.append(stdout_reader_task) try: while True: - # Receive incoming WebM audio chunks from the client message = await websocket.receive_bytes() - try: - ffmpeg_process.stdin.write(message) - ffmpeg_process.stdin.flush() - except (BrokenPipeError, AttributeError) as e: - logger.warning(f"Error writing to FFmpeg: {e}. Restarting...") - await restart_ffmpeg() - ffmpeg_process.stdin.write(message) - ffmpeg_process.stdin.flush() + await audio_processor.process_audio(message) except WebSocketDisconnect: logger.warning("WebSocket disconnected.") finally: - for task in tasks: - task.cancel() - - try: - await asyncio.gather(*tasks, return_exceptions=True) - ffmpeg_process.stdin.close() - ffmpeg_process.wait() - except Exception as e: - logger.warning(f"Error during cleanup: {e}") - - if args.diarization and diarization: - diarization.close() - + websocket_task.cancel() + await audio_processor.cleanup() logger.info("WebSocket endpoint cleaned up.") if __name__ == "__main__": diff --git a/whisper_streaming_custom/whisper_online.py b/whisper_streaming_custom/whisper_online.py index 617f05b..d7263ac 100644 --- a/whisper_streaming_custom/whisper_online.py +++ b/whisper_streaming_custom/whisper_online.py @@ -71,7 +71,7 @@ def add_shared_args(parser): parser.add_argument( "--min-chunk-size", type=float, - default=1.0, + default=0.5, help="Minimum audio chunk size in seconds. 
It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.", ) parser.add_argument(