diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index e5ca35c..85f6797 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -32,7 +32,7 @@ async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np. if isinstance(first_item, Silence): return first_item items.append(first_item) - + while True: if not queue._queue: break @@ -53,15 +53,15 @@ class AudioProcessor: Processes audio streams for transcription and diarization. Handles audio processing, state management, and result formatting. """ - + def __init__(self, **kwargs: Any) -> None: """Initialize the audio processor with configuration, models, and state.""" - + if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine): models = kwargs['transcription_engine'] else: models = TranscriptionEngine(**kwargs) - + # Audio processing settings self.args = models.args self.sample_rate = 16000 @@ -86,13 +86,13 @@ class AudioProcessor: # Models and processing self.asr: Any = models.asr self.vac: Optional[FixedVADIterator] = None - + if self.args.vac: if models.vac_session is not None: vac_model = OnnxWrapper(session=models.vac_session) self.vac = FixedVADIterator(vac_model) else: - self.vac = FixedVADIterator(load_jit_vad()) + self.vac = FixedVADIterator(load_jit_vad()) self.ffmpeg_manager: Optional[FFmpegManager] = None self.ffmpeg_reader_task: Optional[asyncio.Task] = None self._ffmpeg_error: Optional[str] = None @@ -106,7 +106,7 @@ class AudioProcessor: logger.error(f"FFmpeg error: {error_type}") self._ffmpeg_error = error_type self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error - + self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None @@ -117,14 +117,14 @@ class AudioProcessor: self.translation_task: Optional[asyncio.Task] = None self.watchdog_task: Optional[asyncio.Task] = None self.all_tasks_for_cleanup: List[asyncio.Task] = [] - + self.transcription: Optional[Any] = None self.translation: Optional[Any] = None self.diarization: Optional[Any] = None if self.args.transcription: - self.transcription = online_factory(self.args, models.asr) - self.sep = self.transcription.asr.sep + self.transcription = online_factory(self.args, models.asr) + self.sep = self.transcription.asr.sep if self.args.diarization: self.diarization = online_diarization_factory(self.args, models.diarization_model) if models.translation_model: @@ -182,24 +182,24 @@ class AudioProcessor: def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray: """Convert PCM buffer in s16le format to normalized NumPy array.""" return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 - + async def get_current_state(self) -> State: """Get current state.""" async with self.lock: current_time = time() - + remaining_transcription = 0 if self.state.end_buffer > 0: remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1)) - + remaining_diarization = 0 if self.state.tokens: latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1)) - + self.state.remaining_time_transcription = remaining_transcription self.state.remaining_time_diarization = remaining_diarization - + return self.state async def ffmpeg_stdout_reader(self) -> None: @@ -255,7 +255,7 @@ class AudioProcessor: async def transcription_processor(self) -> None: """Process audio chunks for transcription.""" cumulative_pcm_duration_stream_time = 0.0 - + while True: try: # item = await self.transcription_queue.get() @@ -311,12 +311,12 @@ class AudioProcessor: if new_tokens: candidate_end_times.append(new_tokens[-1].end) - + if _buffer_transcript.end is not None: candidate_end_times.append(_buffer_transcript.end) - + candidate_end_times.append(current_audio_processed_upto) - + async with self.lock: self.state.tokens.extend(new_tokens) self.state.buffer_transcription = _buffer_transcript @@ -326,13 +326,13 @@ class AudioProcessor: if self.translation_queue: for token in new_tokens: - await self.translation_queue.put(token) + await self.translation_queue.put(token) except Exception as e: logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue self.transcription_queue.task_done() - + if self.is_stopping: logger.info("Transcription processor finishing due to stopping flag.") if self.diarization_queue: @@ -353,18 +353,21 @@ class AudioProcessor: if item.has_ended: self.diarization.insert_silence(item.duration) continue - self.diarization.insert_audio_chunk(item) diarization_segments = await self.diarization.diarize() - self.state.new_diarization = diarization_segments - + diar_end = 0.0 + if diarization_segments: + diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments) + async with self.lock: + self.state.new_diarization = diarization_segments + self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end) except Exception as e: logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") logger.info("Diarization processor task finished.") async def translation_processor(self) -> None: - # the idea is to ignore diarization for the moment. We use only transcription tokens. + # the idea is to ignore diarization for the moment. We use only transcription tokens. # And the speaker is attributed given the segments used for the translation # in the future we want to have different languages for each speaker etc, so it will be more complex. while True: @@ -426,22 +429,22 @@ class AudioProcessor: remaining_time_transcription=state.remaining_time_transcription, remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 ) - + should_push = (response != self.last_response_content) if should_push: yield response self.last_response_content = response - + if self.is_stopping and self._processing_tasks_done(): logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.") return - + await asyncio.sleep(0.05) - + except Exception as e: logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}") await asyncio.sleep(0.5) - + async def create_tasks(self) -> AsyncGenerator[FrontData, None]: """Create and start processing tasks.""" self.all_tasks_for_cleanup = [] @@ -466,21 +469,21 @@ class AudioProcessor: self.transcription_task = asyncio.create_task(self.transcription_processor()) self.all_tasks_for_cleanup.append(self.transcription_task) processing_tasks_for_watchdog.append(self.transcription_task) - + if self.diarization: self.diarization_task = asyncio.create_task(self.diarization_processor()) self.all_tasks_for_cleanup.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task) - + if self.translation: self.translation_task = asyncio.create_task(self.translation_processor()) self.all_tasks_for_cleanup.append(self.translation_task) processing_tasks_for_watchdog.append(self.translation_task) - + # Monitor overall system health self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog)) self.all_tasks_for_cleanup.append(self.watchdog_task) - + return self.results_formatter() async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None: @@ -493,7 +496,7 @@ class AudioProcessor: return await asyncio.sleep(10) - + for i, task in enumerate(list(tasks_remaining)): if task.done(): exc = task.exception() @@ -503,13 +506,13 @@ class AudioProcessor: else: logger.info(f"{task_name} completed normally.") tasks_remaining.remove(task) - + except asyncio.CancelledError: logger.info("Watchdog task cancelled.") break except Exception as e: logger.error(f"Error in watchdog task: {e}", exc_info=True) - + async def cleanup(self) -> None: """Clean up resources when processing is complete.""" logger.info("Starting cleanup of AudioProcessor resources.") @@ -517,7 +520,7 @@ class AudioProcessor: for task in self.all_tasks_for_cleanup: if task and not task.done(): task.cancel() - + created_tasks = [t for t in self.all_tasks_for_cleanup if t] if created_tasks: await asyncio.gather(*created_tasks, return_exceptions=True) @@ -555,7 +558,7 @@ class AudioProcessor: if not message: logger.info("Empty audio message received, initiating stop sequence.") self.is_stopping = True - + if self.transcription_queue: await self.transcription_queue.put(SENTINEL) @@ -596,7 +599,7 @@ class AudioProcessor: chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec) aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample - + if aligned_chunk_size == 0: return pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size]) @@ -613,7 +616,7 @@ class AudioProcessor: if res is not None: if "start" in res and self.current_silence: await self._end_silence() - + if "end" in res and not self.current_silence: pre_silence_chunk = self._slice_before_silence( pcm_array, chunk_sample_start, res.get("end")