diff --git a/audio.py b/audio.py deleted file mode 100644 index 2436685..0000000 --- a/audio.py +++ /dev/null @@ -1,335 +0,0 @@ -import asyncio -import numpy as np -import ffmpeg -from time import time, sleep - - -from whisper_streaming_custom.whisper_online import online_factory -import math -import logging -import traceback -from state import SharedState -from formatters import format_time - - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logging.getLogger().setLevel(logging.WARNING) -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -class AudioProcessor: - - def __init__(self, args, asr, tokenizer): - self.args = args - self.sample_rate = 16000 - self.channels = 1 - self.samples_per_sec = int(self.sample_rate * args.min_chunk_size) - self.bytes_per_sample = 2 - self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample - self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz - - - self.shared_state = SharedState() - self.asr = asr - self.tokenizer = tokenizer - - self.ffmpeg_process = self.start_ffmpeg_decoder() - - self.transcription_queue = asyncio.Queue() if self.args.transcription else None - self.diarization_queue = asyncio.Queue() if self.args.diarization else None - - self.pcm_buffer = bytearray() - if self.args.transcription: - self.online = online_factory(self.args, self.asr, self.tokenizer) - - - - def convert_pcm_to_float(self, pcm_buffer): - """ - Converts a PCM buffer in s16le format to a normalized NumPy array. - Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format - Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0 - """ - pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) - / 32768.0) - return pcm_array - - def start_ffmpeg_decoder(self): - """ - Start an FFmpeg process in async streaming mode that reads WebM from stdin - and outputs raw s16le PCM on stdout. Returns the process object. 
- """ - process = ( - ffmpeg.input("pipe:0", format="webm") - .output( - "pipe:1", - format="s16le", - acodec="pcm_s16le", - ac=self.channels, - ar=str(self.sample_rate), - ) - .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True) - ) - return process - - async def restart_ffmpeg(self): - if self.ffmpeg_process: - try: - self.ffmpeg_process.kill() - await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait) - except Exception as e: - logger.warning(f"Error killing FFmpeg process: {e}") - self.ffmpeg_process = self.start_ffmpeg_decoder() - self.pcm_buffer = bytearray() - - async def ffmpeg_stdout_reader(self): - loop = asyncio.get_event_loop() - beg = time() - - while True: - try: - elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec - ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096) - beg = time() - - # Read chunk with timeout - try: - chunk = await asyncio.wait_for( - loop.run_in_executor( - None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration - ), - timeout=15.0 - ) - except asyncio.TimeoutError: - logger.warning("FFmpeg read timeout. Restarting...") - await self.restart_ffmpeg() - beg = time() - continue # Skip processing and read from new process - - if not chunk: - logger.info("FFmpeg stdout closed.") - break - self.pcm_buffer.extend(chunk) - - if self.args.diarization and self.diarization_queue: - await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy()) - - if len(self.pcm_buffer) >= self.bytes_per_sec: - if len(self.pcm_buffer) > self.max_bytes_per_sec: - logger.warning( - f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds. - The model probably struggles to keep up. Consider using a smaller model. 
- """) - - pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec]) - self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:] - - if self.args.transcription and self.transcription_queue: - await self.transcription_queue.put(pcm_array.copy()) - - - if not self.args.transcription and not self.args.diarization: - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in ffmpeg_stdout_reader: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - break - logger.info("Exiting ffmpeg_stdout_reader...") - - async def transcription_processor(self): - full_transcription = "" - sep = self.online.asr.sep - - while True: - try: - pcm_array = await self.transcription_queue.get() - - logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.") - - # Process transcription - self.online.insert_audio_chunk(pcm_array) - new_tokens = self.online.process_iter() - - if new_tokens: - full_transcription += sep.join([t.text for t in new_tokens]) - - _buffer = self.online.get_buffer() - buffer = _buffer.text - end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0) - - if buffer in full_transcription: - buffer = "" - - await self.shared_state.update_transcription( - new_tokens, buffer, end_buffer, full_transcription, sep) - - except Exception as e: - logger.warning(f"Exception in transcription_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - self.transcription_queue.task_done() - - - async def diarization_processor(self, diarization_obj): - buffer_diarization = "" - - while True: - try: - pcm_array = await self.diarization_queue.get() - - # Process diarization - await diarization_obj.diarize(pcm_array) - - # Get current state - state = await self.shared_state.get_current_state() - tokens = state["tokens"] - end_attributed_speaker = state["end_attributed_speaker"] - - # Update speaker information - 
new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens( - end_attributed_speaker, tokens) - - await self.shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization) - - except Exception as e: - logger.warning(f"Exception in diarization_processor: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - finally: - self.diarization_queue.task_done() - - async def results_formatter(self): - while True: - try: - state = await self.shared_state.get_current_state() - tokens = state["tokens"] - buffer_transcription = state["buffer_transcription"] - buffer_diarization = state["buffer_diarization"] - end_attributed_speaker = state["end_attributed_speaker"] - remaining_time_transcription = state["remaining_time_transcription"] - remaining_time_diarization = state["remaining_time_diarization"] - sep = state["sep"] - - # If diarization is enabled but no transcription, add dummy tokens periodically - if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization: - await self.shared_state.add_dummy_token() - sleep(0.5) - state = await self.shared_state.get_current_state() - tokens = state["tokens"] - previous_speaker = -1 - lines = [] - last_end_diarized = 0 - undiarized_text = [] - - for token in tokens: - speaker = token.speaker - if self.args.diarization: - if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: - undiarized_text.append(token.text) - continue - elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: - speaker = previous_speaker - if speaker not in [-1, 0]: - last_end_diarized = max(token.end, last_end_diarized) - - if speaker != previous_speaker or not lines: - lines.append( - { - "speaker": speaker, - "text": token.text, - "beg": format_time(token.start), - "end": format_time(token.end), - "diff": round(token.end - last_end_diarized, 2) - } - ) - previous_speaker = speaker - elif token.text: # Only append if text isn't empty - 
lines[-1]["text"] += sep + token.text - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) - - if undiarized_text: - combined_buffer_diarization = sep.join(undiarized_text) - if buffer_transcription: - combined_buffer_diarization += sep - await self.shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization) - buffer_diarization = combined_buffer_diarization - - if lines: - response = { - "lines": lines, - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - else: - response = { - "lines": [{ - "speaker": 1, - "text": "", - "beg": format_time(0), - "end": format_time(tokens[-1].end) if tokens else format_time(0), - "diff": 0 - }], - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization - - if response_content != self.shared_state.last_response_content: - if lines or buffer_transcription or buffer_diarization: - yield response - self.shared_state.last_response_content = response_content - - #small delay to avoid overwhelming the client - await asyncio.sleep(0.1) - - except Exception as e: - logger.warning(f"Exception in results_formatter: {e}") - logger.warning(f"Traceback: {traceback.format_exc()}") - await asyncio.sleep(0.5) # Back off on error - - async def create_tasks(self, diarization=None): - if diarization: - self.diarization = diarization - - tasks = [] - if self.args.transcription and self.online: - tasks.append(asyncio.create_task(self.transcription_processor())) - if self.args.diarization and 
import asyncio
import math
import logging
import traceback
from datetime import timedelta
from time import time

import numpy as np
import ffmpeg

from timed_objects import ASRToken
from whisper_streaming_custom.whisper_online import online_factory

# Set up logging once for the module.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def format_time(seconds: float) -> str:
    """Format a duration in seconds as HH:MM:SS (fractional part truncated)."""
    return str(timedelta(seconds=int(seconds)))


class AudioProcessor:
    """
    Processes an incoming WebM audio stream for transcription and diarization.

    Responsibilities:
      * decode incoming WebM bytes to 16 kHz mono s16le PCM via FFmpeg,
      * fan PCM chunks out to the transcription and diarization workers,
      * maintain a lock-protected shared state of tokens and text buffers,
      * format incremental, deduplicated results for streaming to the client.
    """

    def __init__(self, args, asr, tokenizer):
        """Initialize configuration, models, queues and shared state.

        Args:
            args: parsed CLI options; this class reads ``min_chunk_size``,
                ``transcription`` and ``diarization``.
            asr: ASR backend passed through to ``online_factory``.
            tokenizer: sentence tokenizer passed through to ``online_factory``.
        """
        # Audio processing settings.
        self.args = args
        self.sample_rate = 16000
        self.channels = 1
        self.samples_per_sec = int(self.sample_rate * args.min_chunk_size)
        self.bytes_per_sample = 2  # s16le
        self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
        self.max_bytes_per_sec = 32000 * 5  # 5 s of audio at 32 kB/s (16 kHz * 2 bytes)

        # Shared state (guarded by self.lock).
        self.tokens = []
        self.buffer_transcription = ""
        self.buffer_diarization = ""
        self.full_transcription = ""
        self.end_buffer = 0
        self.end_attributed_speaker = 0
        self.lock = asyncio.Lock()
        self.beg_loop = time()
        self.sep = " "  # Default separator between tokens
        self.last_response_content = ""

        # Models and processing pipeline.
        self.asr = asr
        self.tokenizer = tokenizer
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.transcription_queue = asyncio.Queue() if args.transcription else None
        self.diarization_queue = asyncio.Queue() if args.diarization else None
        self.pcm_buffer = bytearray()

        # BUGFIX: always define these so create_tasks()/cleanup() can test them
        # without AttributeError when the corresponding feature is disabled.
        self.diarization = None
        self.tasks = []
        self.online = online_factory(args, asr, tokenizer) if args.transcription else None

    def convert_pcm_to_float(self, pcm_buffer):
        """Convert an s16le PCM buffer to a float32 NumPy array in [-1.0, 1.0)."""
        return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0

    def start_ffmpeg_decoder(self):
        """Start an FFmpeg process that decodes WebM on stdin to raw PCM on stdout."""
        return (
            ffmpeg.input("pipe:0", format="webm")
            .output("pipe:1", format="s16le", acodec="pcm_s16le",
                    ac=self.channels, ar=str(self.sample_rate))
            .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
        )

    async def restart_ffmpeg(self):
        """Kill and restart the FFmpeg process (e.g. after a read timeout)."""
        if self.ffmpeg_process:
            try:
                self.ffmpeg_process.kill()
                # wait() blocks, so run it in the default executor.
                await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
            except Exception as e:
                logger.warning(f"Error killing FFmpeg process: {e}")
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.pcm_buffer = bytearray()

    async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
        """Atomically merge new transcription results into the shared state."""
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.buffer_transcription = buffer
            self.end_buffer = end_buffer
            self.full_transcription = full_transcription
            self.sep = sep

    async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
        """Atomically merge new diarization results into the shared state."""
        async with self.lock:
            self.end_attributed_speaker = end_attributed_speaker
            if buffer_diarization:
                self.buffer_diarization = buffer_diarization

    async def add_dummy_token(self):
        """Append a placeholder token so diarization-only streams keep advancing."""
        async with self.lock:
            current_time = time() - self.beg_loop
            self.tokens.append(ASRToken(
                start=current_time, end=current_time + 1,
                text=".", speaker=-1, is_dummy=True,
            ))

    async def get_current_state(self):
        """Return a consistent snapshot of the shared state plus lag estimates."""
        async with self.lock:
            current_time = time()

            # Transcription lag: wall time elapsed beyond the last emitted buffer end.
            remaining_transcription = 0
            if self.end_buffer > 0:
                remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))

            # Diarization lag: audio transcribed but not yet speaker-attributed.
            remaining_diarization = 0
            if self.tokens:
                latest_end = max(self.end_buffer, self.tokens[-1].end)
                remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))

            return {
                "tokens": self.tokens.copy(),
                "buffer_transcription": self.buffer_transcription,
                "buffer_diarization": self.buffer_diarization,
                "end_buffer": self.end_buffer,
                "end_attributed_speaker": self.end_attributed_speaker,
                "sep": self.sep,
                "remaining_time_transcription": remaining_transcription,
                "remaining_time_diarization": remaining_diarization,
            }

    async def reset(self):
        """Reset all state variables to their initial values."""
        async with self.lock:
            self.tokens = []
            self.buffer_transcription = self.buffer_diarization = ""
            self.end_buffer = self.end_attributed_speaker = 0
            self.full_transcription = self.last_response_content = ""
            self.beg_loop = time()

    async def ffmpeg_stdout_reader(self):
        """Read PCM from FFmpeg stdout and fan it out to the worker queues."""
        loop = asyncio.get_event_loop()
        beg = time()

        while True:
            try:
                # Scale the read size to the audio duration elapsed since the last read.
                elapsed_time = math.floor((time() - beg) * 10) / 10  # Round down to 0.1 s
                buffer_size = max(int(32000 * elapsed_time), 4096)
                beg = time()

                # Read a chunk with timeout; blocking read runs in the executor.
                try:
                    chunk = await asyncio.wait_for(
                        loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
                        timeout=15.0,
                    )
                except asyncio.TimeoutError:
                    logger.warning("FFmpeg read timeout. Restarting...")
                    await self.restart_ffmpeg()
                    beg = time()
                    continue

                if not chunk:
                    logger.info("FFmpeg stdout closed.")
                    break

                self.pcm_buffer.extend(chunk)

                # Send to diarization if enabled.
                if self.args.diarization and self.diarization_queue:
                    await self.diarization_queue.put(
                        self.convert_pcm_to_float(self.pcm_buffer).copy()
                    )

                # Process once we have at least one chunk's worth of audio.
                if len(self.pcm_buffer) >= self.bytes_per_sec:
                    if len(self.pcm_buffer) > self.max_bytes_per_sec:
                        logger.warning(
                            f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
                            f"Consider using a smaller model."
                        )
                    # Transcribe at most max_bytes_per_sec at a time; keep the rest.
                    pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
                    self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]

                    if self.args.transcription and self.transcription_queue:
                        await self.transcription_queue.put(pcm_array.copy())

                # Sleep if no downstream processing is happening.
                if not self.args.transcription and not self.args.diarization:
                    await asyncio.sleep(0.1)

            except Exception as e:
                logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                break

    async def transcription_processor(self):
        """Consume PCM chunks from the transcription queue and run the ASR engine."""
        self.full_transcription = ""
        self.sep = self.online.asr.sep

        while True:
            # BUGFIX: get() is outside the try so task_done() is only called for
            # items actually retrieved (avoids "task_done() called too many times"
            # when get() itself is interrupted, e.g. on cancellation).
            pcm_array = await self.transcription_queue.get()
            try:
                logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")

                # Process transcription.
                self.online.insert_audio_chunk(pcm_array)
                new_tokens = self.online.process_iter()

                if new_tokens:
                    self.full_transcription += self.sep.join([t.text for t in new_tokens])

                # Get buffer information.
                _buffer = self.online.get_buffer()
                buffer = _buffer.text
                end_buffer = _buffer.end if _buffer.end else (
                    new_tokens[-1].end if new_tokens else 0
                )

                # Avoid duplicating content already committed to the transcript.
                if buffer in self.full_transcription:
                    buffer = ""

                await self.update_transcription(
                    new_tokens, buffer, end_buffer, self.full_transcription, self.sep
                )

            except Exception as e:
                logger.warning(f"Exception in transcription_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.transcription_queue.task_done()

    async def diarization_processor(self, diarization_obj):
        """Consume PCM chunks and attribute speakers to transcribed tokens."""
        buffer_diarization = ""

        while True:
            # BUGFIX: get() outside the try — see transcription_processor.
            pcm_array = await self.diarization_queue.get()
            try:
                # Process diarization.
                await diarization_obj.diarize(pcm_array)

                # Get current state and update speaker attribution.
                state = await self.get_current_state()
                new_end = diarization_obj.assign_speakers_to_tokens(
                    state["end_attributed_speaker"], state["tokens"]
                )

                await self.update_diarization(new_end, buffer_diarization)

            except Exception as e:
                logger.warning(f"Exception in diarization_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.diarization_queue.task_done()

    async def results_formatter(self):
        """Yield formatted, deduplicated result payloads for the client."""
        while True:
            try:
                # Get a snapshot of the current state.
                state = await self.get_current_state()
                tokens = state["tokens"]
                buffer_transcription = state["buffer_transcription"]
                buffer_diarization = state["buffer_diarization"]
                end_attributed_speaker = state["end_attributed_speaker"]
                sep = state["sep"]

                # Keep the stream alive with dummy tokens when only diarization runs.
                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                    await self.add_dummy_token()
                    # BUGFIX: was blocking time.sleep(0.5), which froze the event loop.
                    await asyncio.sleep(0.5)
                    state = await self.get_current_state()
                    tokens = state["tokens"]

                # Format output: group consecutive tokens by speaker.
                previous_speaker = -1
                lines = []
                last_end_diarized = 0
                undiarized_text = []

                for token in tokens:
                    speaker = token.speaker

                    # Handle diarization attribution.
                    if self.args.diarization:
                        if speaker in [-1, 0] and token.end >= end_attributed_speaker:
                            # Not yet attributed to a speaker; hold in the buffer.
                            undiarized_text.append(token.text)
                            continue
                        elif speaker in [-1, 0] and token.end < end_attributed_speaker:
                            speaker = previous_speaker
                        if speaker not in [-1, 0]:
                            last_end_diarized = max(token.end, last_end_diarized)

                    # Start a new line on speaker change, else append to the current one.
                    if speaker != previous_speaker or not lines:
                        lines.append({
                            "speaker": speaker,
                            "text": token.text,
                            "beg": format_time(token.start),
                            "end": format_time(token.end),
                            "diff": round(token.end - last_end_diarized, 2),
                        })
                        previous_speaker = speaker
                    elif token.text:  # Only append if text isn't empty.
                        lines[-1]["text"] += sep + token.text
                        lines[-1]["end"] = format_time(token.end)
                        lines[-1]["diff"] = round(token.end - last_end_diarized, 2)

                # Surface text awaiting speaker attribution via the diarization buffer.
                if undiarized_text:
                    combined = sep.join(undiarized_text)
                    if buffer_transcription:
                        combined += sep
                    await self.update_diarization(end_attributed_speaker, combined)
                    buffer_diarization = combined

                # Always emit at least one (possibly empty) line.
                if not lines:
                    lines = [{
                        "speaker": 1,
                        "text": "",
                        "beg": format_time(0),
                        "end": format_time(tokens[-1].end if tokens else 0),
                        "diff": 0,
                    }]

                response = {
                    "lines": lines,
                    "buffer_transcription": buffer_transcription,
                    "buffer_diarization": buffer_diarization,
                    "remaining_time_transcription": state["remaining_time_transcription"],
                    "remaining_time_diarization": state["remaining_time_diarization"],
                }

                # Only yield if the content actually changed.
                response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
                    f" | {buffer_transcription} | {buffer_diarization}"

                if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
                    yield response
                    self.last_response_content = response_content

                await asyncio.sleep(0.1)  # Avoid overwhelming the client.

            except Exception as e:
                logger.warning(f"Exception in results_formatter: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                await asyncio.sleep(0.5)  # Back off on error.

    async def create_tasks(self, diarization=None):
        """Start the background workers and return the results generator."""
        if diarization:
            self.diarization = diarization

        tasks = []
        if self.args.transcription and self.online:
            tasks.append(asyncio.create_task(self.transcription_processor()))

        # self.diarization is initialized to None in __init__, so this test is
        # safe even when no diarization object was ever supplied.
        if self.args.diarization and self.diarization:
            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))

        tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
        self.tasks = tasks

        return self.results_formatter()

    async def cleanup(self):
        """Cancel workers and release the FFmpeg process and diarization engine."""
        for task in self.tasks:
            task.cancel()

        try:
            await asyncio.gather(*self.tasks, return_exceptions=True)
            self.ffmpeg_process.stdin.close()
            # BUGFIX: wait() blocks; run it in the executor like restart_ffmpeg does.
            await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")

        if self.args.diarization and self.diarization:
            self.diarization.close()

    async def process_audio(self, message):
        """Forward incoming WebM bytes to FFmpeg, restarting it once on failure."""
        try:
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
        except (BrokenPipeError, AttributeError) as e:
            logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
            await self.restart_ffmpeg()
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
- - Args: - state: Current shared state dictionary - with_diarization: Whether to include diarization formatting - - Returns: - Formatted response dictionary ready to send to client - """ - tokens = state["tokens"] - buffer_transcription = state["buffer_transcription"] - buffer_diarization = state["buffer_diarization"] - end_attributed_speaker = state["end_attributed_speaker"] - remaining_time_transcription = state["remaining_time_transcription"] - remaining_time_diarization = state["remaining_time_diarization"] - sep = state["sep"] - - # Default response for empty state - if not tokens: - return { - "lines": [{ - "speaker": 1, - "text": "", - "beg": format_time(0), - "end": format_time(0), - "diff": 0 - }], - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - # Process tokens to create response - previous_speaker = -1 - lines = [] - last_end_diarized = 0 - undiarized_text = [] - - for token in tokens: - speaker = token.speaker - - # Handle diarization logic - if with_diarization: - if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker: - undiarized_text.append(token.text) - continue - elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker: - speaker = previous_speaker - - if speaker not in [-1, 0]: - last_end_diarized = max(token.end, last_end_diarized) - - # Add new line or append to existing line - if speaker != previous_speaker or not lines: - lines.append({ - "speaker": speaker, - "text": token.text, - "beg": format_time(token.start), - "end": format_time(token.end), - "diff": round(token.end - last_end_diarized, 2) - }) - previous_speaker = speaker - elif token.text: # Only append if text isn't empty - lines[-1]["text"] += sep + token.text - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) - - # 
If we have undiarized text, include it in the buffer - if undiarized_text: - combined_buffer = sep.join(undiarized_text) - if buffer_transcription: - combined_buffer += sep + buffer_transcription - buffer_diarization = combined_buffer - - return { - "lines": lines, - "buffer_transcription": buffer_transcription, - "buffer_diarization": buffer_diarization, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } \ No newline at end of file diff --git a/state.py b/state.py deleted file mode 100644 index a4c864a..0000000 --- a/state.py +++ /dev/null @@ -1,96 +0,0 @@ -import asyncio -import logging -from time import time -from typing import List, Dict, Any, Optional -from dataclasses import dataclass, field -from timed_objects import ASRToken - -logger = logging.getLogger(__name__) - - -class SharedState: - """ - Thread-safe state manager for streaming transcription and diarization. - Handles coordination between audio processing, transcription, and diarization. 
- """ - - def __init__(self): - self.tokens: List[ASRToken] = [] - self.buffer_transcription: str = "" - self.buffer_diarization: str = "" - self.full_transcription: str = "" - self.end_buffer: float = 0 - self.end_attributed_speaker: float = 0 - self.lock = asyncio.Lock() - self.beg_loop: float = time() - self.sep: str = " " # Default separator - self.last_response_content: str = "" # To track changes in response - - async def update_transcription(self, new_tokens: List[ASRToken], buffer: str, - end_buffer: float, full_transcription: str, sep: str) -> None: - """Update the state with new transcription data.""" - async with self.lock: - self.tokens.extend(new_tokens) - self.buffer_transcription = buffer - self.end_buffer = end_buffer - self.full_transcription = full_transcription - self.sep = sep - - async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None: - """Update the state with new diarization data.""" - async with self.lock: - self.end_attributed_speaker = end_attributed_speaker - if buffer_diarization: - self.buffer_diarization = buffer_diarization - - async def add_dummy_token(self) -> None: - """Add a dummy token to keep the state updated even without transcription.""" - async with self.lock: - current_time = time() - self.beg_loop - dummy_token = ASRToken( - start=current_time, - end=current_time + 1, - text=".", - speaker=-1, - is_dummy=True - ) - self.tokens.append(dummy_token) - - async def get_current_state(self) -> Dict[str, Any]: - """Get the current state with calculated timing information.""" - async with self.lock: - current_time = time() - remaining_time_transcription = 0 - remaining_time_diarization = 0 - - # Calculate remaining time for transcription buffer - if self.end_buffer > 0: - remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2)) - - # Calculate remaining time for diarization - if self.tokens: - latest_end = max(self.end_buffer, 
self.tokens[-1].end if self.tokens else 0) - remaining_time_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2)) - - return { - "tokens": self.tokens.copy(), - "buffer_transcription": self.buffer_transcription, - "buffer_diarization": self.buffer_diarization, - "end_buffer": self.end_buffer, - "end_attributed_speaker": self.end_attributed_speaker, - "sep": self.sep, - "remaining_time_transcription": remaining_time_transcription, - "remaining_time_diarization": remaining_time_diarization - } - - async def reset(self) -> None: - """Reset the state to initial values.""" - async with self.lock: - self.tokens = [] - self.buffer_transcription = "" - self.buffer_diarization = "" - self.end_buffer = 0 - self.end_attributed_speaker = 0 - self.full_transcription = "" - self.beg_loop = time() - self.last_response_content = "" \ No newline at end of file diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index cf81ff8..6c41f27 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -8,7 +8,7 @@ from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr import asyncio import logging from parse_args import parse_args -from audio import AudioProcessor +from audio_processor import AudioProcessor logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.getLogger().setLevel(logging.WARNING) @@ -80,7 +80,7 @@ async def websocket_endpoint(websocket: WebSocket): logger.warning("WebSocket disconnected.") finally: websocket_task.cancel() - audio_processor.cleanup() + await audio_processor.cleanup() logger.info("WebSocket endpoint cleaned up.") if __name__ == "__main__":