translator takes all the tokens from the queue

This commit is contained in:
Quentin Fuxa
2025-09-09 19:55:39 +02:00
parent da8726b2cb
commit add7ea07ee
6 changed files with 105 additions and 12 deletions

View File

@@ -4,15 +4,18 @@ import transformers
from dataclasses import dataclass
import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code
from timed_objects import Translation
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@dataclass
class TranslationModel():
translator: ctranslate2.Translator
tokenizer: dict()
tokenizer: dict
def load_model(src_langs):
MODEL = 'nllb-200-distilled-600M-ctranslate2'
@@ -38,7 +41,8 @@ def translate(input, translation_model, tgt_lang):
class OnlineTranslation:
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
self.buffer = []
self.commited = []
self.validated = []
self.translation_pending_validation = ''
self.translation_model = translation_model
self.input_languages = input_languages
self.output_languages = output_languages
@@ -68,6 +72,39 @@ class OnlineTranslation:
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
return results
def translate_tokens(self, tokens):
if tokens:
text = ' '.join([token.text for token in tokens])
start = tokens[0].start
end = tokens[-1].end
translated_text = self.translate(text)
translation = Translation(
text=translated_text,
start=start,
end=end,
)
return translation
return None
def insert_tokens(self, tokens):
self.buffer.extend(tokens)
pass
def process(self):
i = 0
while i < len(self.buffer):
if self.buffer[i].text in PUNCTUATION_MARKS:
translation_sentence = self.translate_tokens(self.buffer[:i+1])
self.validated.append(translation_sentence)
self.buffer = self.buffer[i+1:]
i = 0
else:
i+=1
translation_remaining = self.translate_tokens(self.buffer)
return self.validated + [translation_remaining]
if __name__ == '__main__':
output_lang = 'fr'