From add7ea07ee009b99f4ed3468d7cbff4227842ece Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Tue, 9 Sep 2025 19:55:39 +0200 Subject: [PATCH] translator takes all the tokens from the queue --- whisperlivekit/audio_processor.py | 62 +++++++++++++++++-- whisperlivekit/core.py | 6 +- whisperlivekit/results_formater.py | 3 +- whisperlivekit/simul_whisper/simul_whisper.py | 1 - whisperlivekit/timed_objects.py | 4 ++ whisperlivekit/translation/translation.py | 41 +++++++++++- 6 files changed, 105 insertions(+), 12 deletions(-) diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 98d2557..c99d013 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -16,6 +16,17 @@ logger.setLevel(logging.DEBUG) SENTINEL = object() # unique sentinel object for end of stream marker + +async def get_all_from_queue(queue): + items = [] + try: + while True: + item = queue.get_nowait() + items.append(item) + except asyncio.QueueEmpty: + pass + return items + class AudioProcessor: """ Processes audio streams for transcription and diarization. @@ -265,6 +276,8 @@ class AudioProcessor: if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(SENTINEL) logger.debug("Sentinel put into diarization_queue.") + if self.args.target_language and self.translation_queue: + await self.translation_queue.put(SENTINEL) async def transcription_processor(self): @@ -308,9 +321,6 @@ class AudioProcessor: cumulative_pcm_duration_stream_time += duration_this_chunk stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time - - - self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) new_tokens, current_audio_processed_upto = self.online.process_iter() @@ -338,6 +348,11 @@ class AudioProcessor: await self.update_transcription( new_tokens, buffer_text, new_end_buffer, self.sep ) + + if new_tokens and self.args.target_language and self.translation_queue: + for token in new_tokens: + await self.translation_queue.put(token) + self.transcription_queue.task_done() except Exception as e: @@ -398,9 +413,44 @@ class AudioProcessor: # in the future we want to have different languages for each speaker etc, so it will be more complex. while True: try: - item = await self.translation_queue.get() + token = await self.translation_queue.get() #block until at least 1 token + if token is SENTINEL: + logger.debug("Translation processor received sentinel. Finishing.") + self.translation_queue.task_done() + break + + # get all the available tokens for translation. The more words, the more precise + tokens_to_process = [token] + additional_tokens = await get_all_from_queue(self.translation_queue) + + sentinel_found = False + for additional_token in additional_tokens: + if additional_token is SENTINEL: + sentinel_found = True + break + tokens_to_process.append(additional_token) + if tokens_to_process: + online_translation.insert_tokens(tokens_to_process) + translations = online_translation.process() + print(translations) + + self.translation_queue.task_done() + for _ in additional_tokens: + self.translation_queue.task_done() + + if sentinel_found: + logger.debug("Translation processor received sentinel in batch. Finishing.") + break + except Exception as e: logger.warning(f"Exception in translation_processor: {e}") + logger.warning(f"Traceback: {traceback.format_exc()}") + if 'token' in locals() and token is not SENTINEL: + self.translation_queue.task_done() + if 'additional_tokens' in locals(): + for _ in additional_tokens: + self.translation_queue.task_done() + logger.info("Translation processor task finished.") async def results_formatter(self): """Format processing results for output.""" @@ -546,8 +596,10 @@ class AudioProcessor: self.all_tasks_for_cleanup.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task) - if self.args.target_language and self.args.language != 'auto': + if self.args.target_language and self.args.lan != 'auto': self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation)) + self.all_tasks_for_cleanup.append(self.translation_task) + processing_tasks_for_watchdog.append(self.translation_task) self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task) diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index b3bbbe7..f59e629 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -136,11 +136,11 @@ class TranscriptionEngine: self.translation_model = None if self.args.target_language: - if self.args.language == 'auto': + if self.args.lan == 'auto': raise Exception('Translation cannot be set with language auto') else: from whisperlivekit.translation.translation import load_model - self.translation_model = load_model([self.args.language]) #in the future we want to handle different languages for different speakers + self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers TranscriptionEngine._initialized = True @@ -181,4 +181,4 @@ def online_translation_factory(args, translation_model): #one shared nllb model for all speaker #one tokenizer per speaker/language from whisperlivekit.translation.translation import OnlineTranslation - online = OnlineTranslation(translation_model, [args.language], [args.target_language]) \ No newline at end of file + return OnlineTranslation(translation_model, [args.lan], [args.target_language]) \ No newline at end of file diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index 3abe4f5..e6f664d 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -6,7 +6,7 @@ from whisperlivekit.remove_silences import handle_silences logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -PUNCTUATION_MARKS = {'.', '!', '?'} +PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} CHECK_AROUND = 4 def format_time(seconds: float) -> str: @@ -59,6 +59,7 @@ def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized): def format_output(state, silence, current_time, diarization, debug): tokens = state["tokens"] + translated_tokens = state["translated_tokens"] # Here we will attribute the speakers only based on the timestamps of the segments buffer_transcription = state["buffer_transcription"] buffer_diarization = state["buffer_diarization"] end_attributed_speaker = state["end_attributed_speaker"] diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index b537510..cc183c4 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -174,7 +174,6 @@ class PaddedAlignAttWhisper: self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) def remove_hooks(self): - print('remove hook') for hook in self.l_hooks: hook.remove() diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index c2a3619..c8ad3ec 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -31,6 +31,10 @@ class SpeakerSegment(TimedText): """ pass +@dataclass +class Translation(TimedText): + pass + @dataclass class Silence(): duration: float \ No newline at end of file diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index edc6677..7a4c734 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -4,15 +4,18 @@ import transformers from dataclasses import dataclass import huggingface_hub from whisperlivekit.translation.mapping_languages import get_nllb_code +from timed_objects import Translation + #In diarization case, we may want to translate just one speaker, or at least start the sentences there +PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} @dataclass class TranslationModel(): translator: ctranslate2.Translator - tokenizer: dict() + tokenizer: dict def load_model(src_langs): MODEL = 'nllb-200-distilled-600M-ctranslate2' @@ -38,7 +41,8 @@ def translate(input, translation_model, tgt_lang): class OnlineTranslation: def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): self.buffer = [] - self.commited = [] + self.validated = [] + self.translation_pending_validation = '' self.translation_model = translation_model self.input_languages = input_languages self.output_languages = output_languages @@ -68,6 +72,39 @@ class OnlineTranslation: results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target)) return results + def translate_tokens(self, tokens): + if tokens: + text = ' '.join([token.text for token in tokens]) + start = tokens[0].start + end = tokens[-1].end + translated_text = self.translate(text) + translation = Translation( + text=translated_text, + start=start, + end=end, + ) + return translation + return None + + + + def insert_tokens(self, tokens): + self.buffer.extend(tokens) + pass + + def process(self): + i = 0 + while i < len(self.buffer): + if self.buffer[i].text in PUNCTUATION_MARKS: + translation_sentence = self.translate_tokens(self.buffer[:i+1]) + self.validated.append(translation_sentence) + self.buffer = self.buffer[i+1:] + i = 0 + else: + i+=1 + translation_remaining = self.translate_tokens(self.buffer) + return self.validated + [translation_remaining] + if __name__ == '__main__': output_lang = 'fr'