From cd160caaa1be7fa8f7ef89e625bfc50498c968a4 Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Sat, 13 Sep 2025 22:06:00 +0200
Subject: [PATCH] asyncio.to_thread for transcription and translation

---
 whisperlivekit/audio_processor.py  | 125 +++++++++++------------------
 whisperlivekit/basic_server.py     |   2 +-
 whisperlivekit/results_formater.py |  15 ++--
 whisperlivekit/timed_objects.py    |  38 ++++++++-
 4 files changed, 90 insertions(+), 90 deletions(-)

diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index 92a3098..525c4fc 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -4,7 +4,7 @@ from time import time, sleep
 import math
 import logging
 import traceback
-from whisperlivekit.timed_objects import ASRToken, Silence, Line
+from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State
 from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
 from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
 from whisperlivekit.silero_vad_iterator import FixedVADIterator
@@ -68,7 +68,7 @@ class AudioProcessor:
         self.lock = asyncio.Lock()
         self.beg_loop = None #to deal with a potential little lag at the websocket initialization, this is now set in process_audio
         self.sep = " " # Default separator
-        self.last_response_content = ""
+        self.last_response_content = FrontData()
 
         # Models and processing
         self.asr = models.asr
@@ -103,7 +103,8 @@ class AudioProcessor:
         self.all_tasks_for_cleanup = []
 
         if self.args.transcription:
-            self.online = online_factory(self.args, models.asr, models.tokenizer)
+            self.online = online_factory(self.args, models.asr, models.tokenizer)
+            self.sep = self.online.asr.sep
         if self.args.diarization:
             self.diarization = online_diarization_factory(self.args, models.diarization_model)
         if self.args.target_language:
@@ -113,13 +114,12 @@ class AudioProcessor:
         """Convert PCM buffer in s16le format to normalized NumPy array."""
         return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
 
-    async def update_transcription(self, new_tokens, buffer, end_buffer, sep):
+    async def update_transcription(self, new_tokens, buffer, end_buffer):
         """Thread-safe update of transcription with new data."""
         async with self.lock:
             self.tokens.extend(new_tokens)
             self.buffer_transcription = buffer
             self.end_buffer = end_buffer
-            self.sep = sep
 
     async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
         """Thread-safe update of diarization with new data."""
@@ -152,17 +152,16 @@ class AudioProcessor:
             latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
             remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 1))
 
-            return {
-                "tokens": self.tokens.copy(),
-                "translated_segments": self.translated_segments.copy(),
-                "buffer_transcription": self.buffer_transcription,
-                "buffer_diarization": self.buffer_diarization,
-                "end_buffer": self.end_buffer,
-                "end_attributed_speaker": self.end_attributed_speaker,
-                "sep": self.sep,
-                "remaining_time_transcription": remaining_transcription,
-                "remaining_time_diarization": remaining_diarization
-            }
+            return State(
+                tokens=self.tokens.copy(),
+                translated_segments=self.translated_segments.copy(),
+                buffer_transcription=self.buffer_transcription,
+                buffer_diarization=self.buffer_diarization,
+                end_buffer=self.end_buffer,
+                end_attributed_speaker=self.end_attributed_speaker,
+                remaining_time_transcription=remaining_transcription,
+                remaining_time_diarization=remaining_diarization
+            )
 
     async def reset(self):
         """Reset all state variables to initial values."""
@@ -236,7 +235,6 @@ class AudioProcessor:
 
     async def transcription_processor(self):
         """Process audio chunks for transcription."""
-        self.sep = self.online.asr.sep
         cumulative_pcm_duration_stream_time = 0.0
 
         while True:
@@ -276,7 +274,7 @@
                     stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
                     self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
-                    new_tokens, current_audio_processed_upto = self.online.process_iter()
+                    new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter)
 
                     # Get buffer information
                     _buffer_transcript_obj = self.online.get_buffer()
@@ -300,7 +298,7 @@
                         new_end_buffer = max(candidate_end_times)
 
                     await self.update_transcription(
-                        new_tokens, buffer_text, new_end_buffer, self.sep
+                        new_tokens, buffer_text, new_end_buffer
                     )
 
                     if new_tokens and self.args.target_language and self.translation_queue:
@@ -385,7 +383,7 @@
                         tokens_to_process.append(additional_token)
                 if tokens_to_process:
                     online_translation.insert_tokens(tokens_to_process)
-                    self.translated_segments = online_translation.process()
+                    self.translated_segments = await asyncio.to_thread(online_translation.process)
                 self.translation_queue.task_done()
                 for _ in additional_tokens:
@@ -407,39 +405,25 @@
 
     async def results_formatter(self):
         """Format processing results for output."""
-        last_sent_trans = None
-        last_sent_diar = None
         while True:
             try:
                 ffmpeg_state = await self.ffmpeg_manager.get_state()
                 if ffmpeg_state == FFmpegState.FAILED and self._ffmpeg_error:
-                    yield {
-                        "status": "error",
-                        "error": f"FFmpeg error: {self._ffmpeg_error}",
-                        "lines": [],
-                        "buffer_transcription": "",
-                        "buffer_diarization": "",
-                        "remaining_time_transcription": 0,
-                        "remaining_time_diarization": 0
-                    }
+                    yield FrontData(
+                        status="error",
+                        error=f"FFmpeg error: {self._ffmpeg_error}"
+                    )
                     self._ffmpeg_error = None
                     await asyncio.sleep(1)
                     continue
 
-                # Get current state
                 state = await self.get_current_state()
-                tokens = state["tokens"]
-                buffer_transcription = state["buffer_transcription"]
-                buffer_diarization = state["buffer_diarization"]
-                end_attributed_speaker = state["end_attributed_speaker"]
-                sep = state["sep"]
 
                 # Add dummy tokens if needed
-                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
+                if (not state.tokens or state.tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                    await self.add_dummy_token()
                     sleep(0.5)
                     state = await self.get_current_state()
-                    tokens = state["tokens"]
 
                 # Format output
                 lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
@@ -447,18 +431,19 @@
                     self.silence,
                     current_time = time() - self.beg_loop if self.beg_loop else None,
                     args = self.args,
-                    debug = self.debug
+                    debug = self.debug,
+                    sep=self.sep
                 )
                 # Handle undiarized text
                 if undiarized_text:
-                    combined = sep.join(undiarized_text)
+                    combined = self.sep.join(undiarized_text)
                     if buffer_transcription:
-                        combined += sep
-                    await self.update_diarization(end_attributed_speaker, combined)
+                        combined += self.sep
+                    await self.update_diarization(state.end_attributed_speaker, combined)
                     buffer_diarization = combined
 
                 response_status = "active_transcription"
-                if not tokens and not buffer_transcription and not buffer_diarization:
+                if not state.tokens and not buffer_transcription and not buffer_diarization:
                    response_status = "no_audio_detected"
                     lines = []
                 elif response_status == "active_transcription" and not lines:
@@ -468,32 +453,19 @@
-                        end=state.get("end_buffer", 0)
+                        end=state.end_buffer
                     )]
 
-                response = {
-                    "status": response_status,
-                    "lines": [line.to_dict() for line in lines],
-                    "buffer_transcription": buffer_transcription,
-                    "buffer_diarization": buffer_diarization,
-                    "remaining_time_transcription": state["remaining_time_transcription"],
-                    "remaining_time_diarization": state["remaining_time_diarization"] if self.args.diarization else 0
-                }
-
-                current_response_signature = f"{response_status} | " + \
-                    ' '.join([f"{line.speaker} {line.text}" for line in lines]) + \
-                    f" | {buffer_transcription} | {buffer_diarization}"
-
-                trans = state["remaining_time_transcription"]
-                diar = state["remaining_time_diarization"]
-                should_push = (
-                    current_response_signature != self.last_response_content
-                    or last_sent_trans is None
-                    or round(trans, 1) != round(last_sent_trans, 1)
-                    or round(diar, 1) != round(last_sent_diar, 1)
+                response = FrontData(
+                    status=response_status,
+                    lines=lines,
+                    buffer_transcription=buffer_transcription,
+                    buffer_diarization=buffer_diarization,
+                    remaining_time_transcription=state.remaining_time_transcription,
+                    remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
                 )
-                if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected" or trans > 0 or diar > 0):
+
+                should_push = (response != self.last_response_content)
+                if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
                     yield response
-                    self.last_response_content = current_response_signature
-                    last_sent_trans = trans
-                    last_sent_diar = diar
+                    self.last_response_content = response
 
                 # Check for termination condition
                 if self.is_stopping:
@@ -507,12 +479,12 @@
                         logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
                         return
 
-                await asyncio.sleep(0.1)  # Avoid overwhelming the client
+                await asyncio.sleep(0.05)
 
             except Exception as e:
                 logger.warning(f"Exception in results_formatter: {e}")
                 logger.warning(f"Traceback: {traceback.format_exc()}")
-                await asyncio.sleep(0.5)  # Back off on error
+                await asyncio.sleep(0.5)
 
     async def create_tasks(self):
         """Create and start processing tasks."""
@@ -523,15 +495,10 @@
            if not success:
                 logger.error("Failed to start FFmpeg manager")
                 async def error_generator():
-                    yield {
-                        "status": "error",
-                        "error": "FFmpeg failed to start. Please check that FFmpeg is installed.",
-                        "lines": [],
-                        "buffer_transcription": "",
-                        "buffer_diarization": "",
-                        "remaining_time_transcription": 0,
-                        "remaining_time_diarization": 0
-                    }
+                    yield FrontData(
+                        status="error",
+                        error="FFmpeg failed to start. Please check that FFmpeg is installed."
+                    )
                 return error_generator()
 
         if self.args.transcription and self.online:
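
Note on the two asyncio.to_thread changes above: online.process_iter() and online_translation.process() are blocking, CPU-bound calls, and awaiting them directly in a coroutine would stall the event loop (and every connected websocket) for the full duration of an inference step. asyncio.to_thread moves that work onto a worker thread while the coroutine suspends. A minimal, self-contained sketch of the pattern — not WhisperLiveKit code; blocking_decode and heartbeat are illustrative stand-ins:

    import asyncio
    import time

    def blocking_decode(chunk: bytes) -> str:
        # Stand-in for online.process_iter(): a CPU-bound inference call
        # that would block the event loop if run directly in a coroutine.
        time.sleep(0.5)
        return f"decoded {len(chunk)} bytes"

    async def heartbeat() -> None:
        # Only keeps ticking if the event loop stays free during decoding.
        for _ in range(3):
            await asyncio.sleep(0.2)
            print("loop is responsive")

    async def main() -> None:
        # to_thread runs the blocking call on a worker thread, so the
        # heartbeat continues to run concurrently with the decode.
        result, _ = await asyncio.gather(
            asyncio.to_thread(blocking_decode, b"\x00" * 3200),
            heartbeat(),
        )
        print(result)

    asyncio.run(main())

This buys real concurrency only when the blocking work releases the GIL (as inference backends that drop into native code generally do) or waits on I/O; a pure-Python CPU loop would still contend for it.
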
diff --git a/whisperlivekit/basic_server.py b/whisperlivekit/basic_server.py
index fa4b9da..cd65829 100644
--- a/whisperlivekit/basic_server.py
+++ b/whisperlivekit/basic_server.py
@@ -54,7 +54,7 @@ async def handle_websocket_results(websocket, results_generator):
     """Consumes results from the audio processor and sends them via WebSocket."""
     try:
         async for response in results_generator:
-            await websocket.send_json(response)
+            await websocket.send_json(response.to_dict())
         # when the results_generator finishes it means all audio has been processed
         logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
         await websocket.send_json({"type": "ready_to_stop"})
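
With the formatter now yielding FrontData objects instead of plain dicts, the conversion to JSON happens once, at the websocket boundary. A rough sketch of why the explicit to_dict() call is needed, using trimmed stand-ins rather than the real classes (the full FrontData appears in the timed_objects.py diff below):

    import json
    from dataclasses import dataclass, field

    @dataclass
    class Line:
        # Trimmed stand-in for timed_objects.Line.
        speaker: int = 0
        text: str = ''

        def to_dict(self):
            return {'speaker': self.speaker, 'text': self.text}

    @dataclass
    class FrontData:
        # Trimmed stand-in for timed_objects.FrontData.
        status: str = ''
        lines: list = field(default_factory=list)

        def to_dict(self):
            return {
                'status': self.status,
                'lines': [line.to_dict() for line in self.lines],
            }

    payload = FrontData(status='active_transcription', lines=[Line(1, 'hello')])
    # json.dumps(payload) would raise TypeError: dataclass instances are not
    # JSON serializable by default, hence send_json(response.to_dict()).
    print(json.dumps(payload.to_dict()))
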
diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py
index dee4402..1526ef1 100644
--- a/whisperlivekit/results_formater.py
+++ b/whisperlivekit/results_formater.py
@@ -46,15 +46,14 @@ def append_token_to_last_line(lines, sep, token, debug_info):
         lines[-1].text += sep + token.text + debug_info
         lines[-1].end = token.end
 
-def format_output(state, silence, current_time, args, debug):
+def format_output(state, silence, current_time, args, debug, sep):
     diarization = args.diarization
     disable_punctuation_split = args.disable_punctuation_split
-    tokens = state["tokens"]
-    translated_segments = state["translated_segments"] # Here we will attribute the speakers only based on the timestamps of the segments
-    buffer_transcription = state["buffer_transcription"]
-    buffer_diarization = state["buffer_diarization"]
-    end_attributed_speaker = state["end_attributed_speaker"]
-    sep = state["sep"]
+    tokens = state.tokens
+    translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
+    buffer_transcription = state.buffer_transcription
+    buffer_diarization = state.buffer_diarization
+    end_attributed_speaker = state.end_attributed_speaker
 
     previous_speaker = -1
     lines = []
@@ -128,7 +127,7 @@ def format_output(state, silence, current_time, args, debug):
         for line in lines:
             while cts_idx < len(translated_segments):
                 ts = translated_segments[cts_idx]
-                if ts.start and ts.start >= line.start and ts.end <= line.end:
+                if ts and ts.start and ts.start >= line.start and ts.end <= line.end:
                     line.translation += ts.text + ' '
                     cts_idx += 1
                 else:
diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py
index 09545b5..3acf7c8 100644
--- a/whisperlivekit/timed_objects.py
+++ b/whisperlivekit/timed_objects.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional
 from datetime import timedelta
 
@@ -57,4 +57,38 @@ class Line(TimedText):
             'translation': self.translation,
             'start': format_time(self.start),
             'end': format_time(self.end),
-        }
\ No newline at end of file
+        }
+
+@dataclass
+class FrontData():
+    status: str = ''
+    error: str = ''
+    lines: list[Line] = field(default_factory=list)
+    buffer_transcription: str = ''
+    buffer_diarization: str = ''
+    remaining_time_transcription: float = 0.
+    remaining_time_diarization: float = 0.
+
+    def to_dict(self):
+        _dict = {
+            'status': self.status,
+            'lines': [line.to_dict() for line in self.lines],
+            'buffer_transcription': self.buffer_transcription,
+            'buffer_diarization': self.buffer_diarization,
+            'remaining_time_transcription': self.remaining_time_transcription,
+            'remaining_time_diarization': self.remaining_time_diarization,
+        }
+        if self.error:
+            _dict['error'] = self.error
+        return _dict
+
+@dataclass
+class State():
+    tokens: list
+    translated_segments: list
+    buffer_transcription: str
+    buffer_diarization: str
+    end_buffer: float
+    end_attributed_speaker: float
+    remaining_time_transcription: float
+    remaining_time_diarization: float
\ No newline at end of file
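
A side effect of modeling the response as a dataclass: @dataclass generates a field-by-field __eq__, which is what lets results_formatter collapse the old string signature and rounded-timer bookkeeping into the single `response != self.last_response_content` check. A minimal sketch with a trimmed stand-in for the real FrontData:

    from dataclasses import dataclass

    @dataclass
    class FrontData:
        # Trimmed stand-in; the real class has more fields, but the
        # generated __eq__ compares all of them the same way.
        status: str = ''
        buffer_transcription: str = ''
        remaining_time_transcription: float = 0.

    last = FrontData('active_transcription', 'hello', 1.0)
    same = FrontData('active_transcription', 'hello', 1.0)
    changed = FrontData('active_transcription', 'hello world', 1.2)

    print(last == same)     # True  -> identical payload, nothing pushed
    print(last != changed)  # True  -> payload differs, push an update

Because the remaining_time_* values are ordinary fields of FrontData, a timer change alone makes the comparison fail and triggers a push, which subsumes the old round()-based checks on last_sent_trans and last_sent_diar.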