Improve buffering when using heavy models

This commit is contained in:
Quentin Fuxa
2025-04-27 23:52:00 +02:00
parent bbd4fd6cff
commit 1e67bf97f0

View File

@@ -31,13 +31,29 @@ def cut_at(cumulative_pcm, cut_sec):
async def get_all_from_queue(queue):
    """Wait for one item, then batch the compatible items already queued.

    Blocks until at least one item is available. Control items are returned
    on their own and are never merged into a batch:

    * ``SENTINEL`` is returned as-is and is NOT marked done here. The
      consumers in this file call ``task_done()`` themselves when they see
      the sentinel; acknowledging it here as well would make that second
      call raise ``ValueError`` ("task_done() called too many times").
    * ``Silence`` and ``ChangeSpeaker`` are marked done and returned bare,
      because consumers dispatch on them with ``isinstance`` checks.

    Any other item starts a batch. Items already waiting in the queue are
    appended (and marked done) until the queue is empty or the next item
    is a control item, which is left in the queue for the next call.

    Args:
        queue: The ``asyncio.Queue`` to drain.

    Returns:
        ``SENTINEL``, a single ``Silence``/``ChangeSpeaker``, one
        concatenated ``np.ndarray`` when the batch holds audio arrays, or
        a list of the batched items otherwise (e.g. translation tokens).
    """
    control_types = (Silence, ChangeSpeaker)

    first_item = await queue.get()
    if first_item is SENTINEL:
        # Left for the caller to acknowledge with task_done().
        return first_item
    queue.task_done()
    if isinstance(first_item, control_types):
        return first_item

    items = [first_item]
    # asyncio.Queue has no public "peek"; `_queue` is its underlying
    # container (private attribute - revisit if the queue type changes).
    while queue._queue:
        next_item = queue._queue[0]
        if next_item is SENTINEL or isinstance(next_item, control_types):
            break
        # The peek above proved an item is present and nothing has awaited
        # since, so the non-blocking pop cannot raise QueueEmpty.
        items.append(queue.get_nowait())
        queue.task_done()

    if isinstance(items[0], np.ndarray):
        return np.concatenate(items)
    return items  # translation tokens
class AudioProcessor:
"""
@@ -246,10 +262,10 @@ class AudioProcessor:
while True:
try:
item = await self.transcription_queue.get()
# item = await self.transcription_queue.get()
item = await get_all_from_queue(self.transcription_queue)
if item is SENTINEL:
logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done()
break
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
@@ -277,7 +293,7 @@ class AudioProcessor:
current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
elif isinstance(item, ChangeSpeaker):
self.transcription.new_speaker(item)
self.transcription_queue.task_done()
# self.transcription_queue.task_done()
continue
elif isinstance(item, np.ndarray):
pcm_array = item
@@ -288,8 +304,8 @@ class AudioProcessor:
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
new_tokens = new_tokens or []
else:
self.transcription_queue.task_done()
continue
return
_buffer_transcript = self.transcription.get_buffer()
buffer_text = _buffer_transcript.text
@@ -316,10 +332,7 @@ class AudioProcessor:
if self.translation_queue:
for token in new_tokens:
await self.translation_queue.put(token)
self.transcription_queue.task_done()
await self.translation_queue.put(token)
except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
@@ -343,7 +356,7 @@ class AudioProcessor:
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0))
while True:
try:
item = await self.diarization_queue.get()
item = await get_all_from_queue(self.diarization_queue)
if item is SENTINEL:
logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done()
@@ -406,51 +419,24 @@ class AudioProcessor:
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
try:
item = await self.translation_queue.get() #block until at least 1 token
if item is SENTINEL:
tokens_to_process = await get_all_from_queue(self.translation_queue)
if tokens_to_process is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done()
break
elif type(item) is Silence:
self.translation.insert_silence(item.duration)
continue
# get all the available tokens for translation. The more words, the more precise
tokens_to_process = [item]
additional_tokens = await get_all_from_queue(self.translation_queue)
sentinel_found = False
for additional_token in additional_tokens:
if additional_token is SENTINEL:
sentinel_found = True
break
elif type(additional_token) is Silence and additional_token.has_ended:
self.translation.insert_silence(additional_token.duration)
continue
else:
tokens_to_process.append(additional_token)
elif type(tokens_to_process) is Silence:
if tokens_to_process.has_ended:
self.translation.insert_silence(tokens_to_process.duration)
continue
if tokens_to_process:
self.translation.insert_tokens(tokens_to_process)
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
async with self.lock:
self.state.translation_validated_segments = translation_validated_segments
self.state.buffer_translation = buffer_translation
self.translation_queue.task_done()
for _ in additional_tokens:
self.translation_queue.task_done()
if sentinel_found:
logger.debug("Translation processor received sentinel in batch. Finishing.")
break
except Exception as e:
logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'token' in locals() and item is not SENTINEL:
self.translation_queue.task_done()
if 'additional_tokens' in locals():
for _ in additional_tokens:
self.translation_queue.task_done()
logger.info("Translation processor task finished.")
async def results_formatter(self):