From 4e455b8aab03169921ed3303ec3d7a1486a14c56 Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Sun, 26 Oct 2025 18:51:09 +0100
Subject: [PATCH] translation now separates validated from output buffer tokens

---
 whisperlivekit/translation/translation.py | 50 ++++++++++++-----------
 1 file changed, 26 insertions(+), 24 deletions(-)

diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py
index 3cdce4d..f1d5c14 100644
--- a/whisperlivekit/translation/translation.py
+++ b/whisperlivekit/translation/translation.py
@@ -35,17 +35,19 @@ class TranslationModel():
 
 def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    MODEL = f'nllb-200-distilled-{nllb_size}-ctranslate2'
+    if nllb_backend=='ctranslate2':
+        model = f'nllb-200-distilled-{nllb_size}-ctranslate2'
         MODEL_GUY = 'entai2965'
-        huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
-        translator = ctranslate2.Translator(MODEL,device=device)
+        huggingface_hub.snapshot_download(MODEL_GUY + '/' + model,local_dir=model)
+        translator = ctranslate2.Translator(model,device=device)
     elif nllb_backend=='transformers':
-        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{nllb_size}")
+        model = f"facebook/nllb-200-distilled-{nllb_size}"
+        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(model)
 
     tokenizer = dict()
     for src_lang in src_langs:
         if src_lang != 'auto':
-            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
+            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(model, src_lang=src_lang, clean_up_tokenization_spaces=True)
 
     translation_model = TranslationModel(
         translator=translator,
@@ -61,7 +63,7 @@ def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
 
 class OnlineTranslation:
     def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
-        self.buffer = []
+        self.input_buffer = []
         self.len_processed_buffer = 0
         self.translation_remaining = Translation()
         self.validated = []
@@ -72,13 +74,13 @@ class OnlineTranslation:
 
     def compute_common_prefix(self, results):
         #we dont want want to prune the result for the moment.
-        if not self.buffer:
-            self.buffer = results
+        if not self.input_buffer:
+            self.input_buffer = results
         else:
-            for i in range(min(len(self.buffer), len(results))):
-                if self.buffer[i] != results[i]:
-                    self.commited.extend(self.buffer[:i])
-                    self.buffer = results[i:]
+            for i in range(min(len(self.input_buffer), len(results))):
+                if self.input_buffer[i] != results[i]:
+                    self.commited.extend(self.input_buffer[:i])
+                    self.input_buffer = results[i:]
 
     def translate(self, input, input_lang, output_lang):
         if not input:
@@ -122,28 +124,28 @@ class OnlineTranslation:
 
     def insert_tokens(self, tokens):
-        self.buffer.extend(tokens)
+        self.input_buffer.extend(tokens)
         pass
 
     def process(self):
         i = 0
-        if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
+        if len(self.input_buffer) < self.len_processed_buffer + 3: #nothing new to process
             return self.validated + [self.translation_remaining]
-        while i < len(self.buffer):
-            if self.buffer[i].is_punctuation():
-                translation_sentence = self.translate_tokens(self.buffer[:i+1])
+        while i < len(self.input_buffer):
+            if self.input_buffer[i].is_punctuation():
+                translation_sentence = self.translate_tokens(self.input_buffer[:i+1])
                 self.validated.append(translation_sentence)
-                self.buffer = self.buffer[i+1:]
+                self.input_buffer = self.input_buffer[i+1:]
                 i = 0
             else:
                 i+=1
-        self.translation_remaining = self.translate_tokens(self.buffer)
-        self.len_processed_buffer = len(self.buffer)
-        return self.validated + [self.translation_remaining]
+        self.translation_remaining = self.translate_tokens(self.input_buffer)
+        self.len_processed_buffer = len(self.input_buffer)
+        return self.validated, [self.translation_remaining]
 
     def insert_silence(self, silence_duration: float):
         if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
-            self.buffer = []
+            self.input_buffer = []
             self.validated += [self.translation_remaining]
 
 if __name__ == '__main__':
@@ -157,13 +159,13 @@ if __name__ == '__main__':
 
     test = test_string.split(' ')
     step = len(test) // 3
-    shared_model = load_model([input_lang], nllb_backend='ctranslate2')
+    shared_model = load_model([input_lang], nllb_backend='transformers')
     online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
 
     beg_inference = time.time()
     for id in range(5):
         val = test[id*step : (id+1)*step]
         val_str = ' '.join(val)
-        result = online_translation.translate(val_str)
+        result = online_translation.translate(val_str, input_lang = input_lang, output_lang = output_lang)
         print(result)
     print('inference time:', time.time() - beg_inference)
\ No newline at end of file
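
Note (not part of the patch): with this change, process() returns the validated sentence translations and the still-changing buffer translation separately instead of one flat list. A minimal caller-side sketch, assuming the OnlineTranslation instance from the __main__ block above and that enough new tokens were inserted for the early return (which still yields a flat list) not to be taken:

    # Hypothetical consumer of the new return shape of process():
    # 'validated' holds sentence translations that will no longer change,
    # 'remaining' is a one-element list with the translation of the current input_buffer.
    validated, remaining = online_translation.process()
    for sentence in validated:
        print('final:', sentence)        # safe to commit to the displayed transcript
    for sentence in remaining:
        print('provisional:', sentence)  # may still be revised on the next process() call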