async def get_all_from_queue(queue):
    """Wait for one queue item, then drain the compatible items queued behind it.

    Blocks until at least one item is available. A ``SENTINEL`` or a
    ``Silence`` is returned on its own, unbatched. Otherwise the call keeps
    taking items that are already waiting, stopping at the first one that is
    ``SENTINEL``, a ``Silence``, or of a different kind than the first item
    (NumPy array vs. anything else). The item that stops the drain is left in
    the queue for the next call.

    Every item taken is acknowledged with ``queue.task_done()`` here, the
    sentinel included, so callers must not acknowledge it again:
    ``asyncio.Queue.task_done()`` raises ``ValueError`` when it is called
    more times than there were items placed in the queue.

    Args:
        queue: The ``asyncio.Queue`` to read from.

    Returns:
        ``SENTINEL`` or the ``Silence`` instance when that is the first item;
        a single concatenated ``np.ndarray`` when the first item is an array;
        otherwise a list of the drained items (the translation path).
    """
    first_item = await queue.get()
    queue.task_done()
    if first_item is SENTINEL or isinstance(first_item, Silence):
        return first_item

    # Only batch items of the same kind as the first one. Mixing arrays with
    # other objects would either break np.concatenate or hand the caller a
    # list it cannot dispatch on, after the items were already acknowledged.
    batching_arrays = isinstance(first_item, np.ndarray)
    items = [first_item]
    while not queue.empty():
        # asyncio.Queue has no public peek, so inspect the head of its
        # backing container without removing it.
        next_item = queue._queue[0]
        if next_item is SENTINEL or isinstance(next_item, Silence):
            break
        if isinstance(next_item, np.ndarray) != batching_arrays:
            break
        # The queue is known to be non-empty, so this cannot raise QueueEmpty.
        items.append(queue.get_nowait())
        queue.task_done()

    if batching_arrays:
        return np.concatenate(items)
    return items
Finishing.") - self.transcription_queue.task_done() break asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE @@ -277,7 +293,7 @@ class AudioProcessor: current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm) elif isinstance(item, ChangeSpeaker): self.transcription.new_speaker(item) - self.transcription_queue.task_done() + # self.transcription_queue.task_done() continue elif isinstance(item, np.ndarray): pcm_array = item @@ -288,8 +304,8 @@ class AudioProcessor: new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter) new_tokens = new_tokens or [] else: - self.transcription_queue.task_done() continue + return _buffer_transcript = self.transcription.get_buffer() buffer_text = _buffer_transcript.text @@ -316,10 +332,7 @@ class AudioProcessor: if self.translation_queue: for token in new_tokens: - await self.translation_queue.put(token) - - self.transcription_queue.task_done() - + await self.translation_queue.put(token) except Exception as e: logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") @@ -343,7 +356,7 @@ class AudioProcessor: await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0)) while True: try: - item = await self.diarization_queue.get() + item = await get_all_from_queue(self.diarization_queue) if item is SENTINEL: logger.debug("Diarization processor received sentinel. Finishing.") self.diarization_queue.task_done() @@ -406,51 +419,24 @@ class AudioProcessor: # in the future we want to have different languages for each speaker etc, so it will be more complex. 
while True: try: - item = await self.translation_queue.get() #block until at least 1 token - if item is SENTINEL: + tokens_to_process = await get_all_from_queue(self.translation_queue) + if tokens_to_process is SENTINEL: logger.debug("Translation processor received sentinel. Finishing.") self.translation_queue.task_done() break - elif type(item) is Silence: - self.translation.insert_silence(item.duration) - continue - - # get all the available tokens for translation. The more words, the more precise - tokens_to_process = [item] - additional_tokens = await get_all_from_queue(self.translation_queue) - - sentinel_found = False - for additional_token in additional_tokens: - if additional_token is SENTINEL: - sentinel_found = True - break - elif type(additional_token) is Silence and additional_token.has_ended: - self.translation.insert_silence(additional_token.duration) - continue - else: - tokens_to_process.append(additional_token) + elif type(tokens_to_process) is Silence: + if tokens_to_process.has_ended: + self.translation.insert_silence(tokens_to_process.duration) + continue if tokens_to_process: self.translation.insert_tokens(tokens_to_process) translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process) async with self.lock: self.state.translation_validated_segments = translation_validated_segments self.state.buffer_translation = buffer_translation - self.translation_queue.task_done() - for _ in additional_tokens: - self.translation_queue.task_done() - - if sentinel_found: - logger.debug("Translation processor received sentinel in batch. 
Finishing.") - break - except Exception as e: logger.warning(f"Exception in translation_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") - if 'token' in locals() and item is not SENTINEL: - self.translation_queue.task_done() - if 'additional_tokens' in locals(): - for _ in additional_tokens: - self.translation_queue.task_done() logger.info("Translation processor task finished.") async def results_formatter(self):