translation now separates validated from output buffer tokens

Quentin Fuxa
2025-10-26 18:51:09 +01:00
parent 9434390ad3
commit 4e455b8aab


@@ -35,17 +35,19 @@ class TranslationModel():
 def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    MODEL = f'nllb-200-distilled-{nllb_size}-ctranslate2'
     if nllb_backend=='ctranslate2':
+        model = f'nllb-200-distilled-{nllb_size}-ctranslate2'
         MODEL_GUY = 'entai2965'
-        huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
-        translator = ctranslate2.Translator(MODEL,device=device)
+        huggingface_hub.snapshot_download(MODEL_GUY + '/' + model,local_dir=model)
+        translator = ctranslate2.Translator(model,device=device)
     elif nllb_backend=='transformers':
-        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{nllb_size}")
+        model = f"facebook/nllb-200-distilled-{nllb_size}"
+        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(model)

     tokenizer = dict()
     for src_lang in src_langs:
         if src_lang != 'auto':
-            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
+            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(model, src_lang=src_lang, clean_up_tokenization_spaces=True)

     translation_model = TranslationModel(
         translator=translator,
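
For reference, the ctranslate2 branch above follows the usual NLLB-on-CTranslate2 pattern: download the converted checkpoint, load it with ctranslate2.Translator, tokenize with the matching Hugging Face tokenizer, and pass the target language code as a target prefix. A minimal standalone sketch, not taken from this commit (checkpoint names and language codes are illustrative):

import ctranslate2
import huggingface_hub
import transformers

# Illustrative checkpoint names; the commit builds them from nllb_size instead.
ct2_dir = huggingface_hub.snapshot_download("entai2965/nllb-200-distilled-600M-ctranslate2")
translator = ctranslate2.Translator(ct2_dir, device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn"
)

# Tokenize the source, translate with the target language code as prefix,
# then drop that prefix token before decoding.
source = tokenizer.convert_ids_to_tokens(tokenizer.encode("The weather is nice today."))
result = translator.translate_batch([source], target_prefix=[["fra_Latn"]])
target_tokens = result[0].hypotheses[0][1:]
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens)))
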
@@ -61,7 +63,7 @@ def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
 class OnlineTranslation:
     def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
-        self.buffer = []
+        self.input_buffer = []
         self.len_processed_buffer = 0
         self.translation_remaining = Translation()
         self.validated = []
@@ -72,13 +74,13 @@ class OnlineTranslation:
     def compute_common_prefix(self, results):
         # we don't want to prune the result for the moment.
-        if not self.buffer:
-            self.buffer = results
+        if not self.input_buffer:
+            self.input_buffer = results
         else:
-            for i in range(min(len(self.buffer), len(results))):
-                if self.buffer[i] != results[i]:
-                    self.commited.extend(self.buffer[:i])
-                    self.buffer = results[i:]
+            for i in range(min(len(self.input_buffer), len(results))):
+                if self.input_buffer[i] != results[i]:
+                    self.commited.extend(self.input_buffer[:i])
+                    self.input_buffer = results[i:]

     def translate(self, input, input_lang, output_lang):
         if not input:
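
compute_common_prefix implements a local-agreement style policy: tokens that are identical at the start of the previous and the new hypothesis are treated as committed, and only the differing tail stays in input_buffer for re-translation. A simplified, self-contained illustration of that idea (not the exact logic of the method above):

def split_on_common_prefix(previous, current):
    """Return (stable, pending): the shared prefix and the still-changing tail."""
    i = 0
    while i < min(len(previous), len(current)) and previous[i] == current[i]:
        i += 1
    return current[:i], current[i:]

stable, pending = split_on_common_prefix(
    ["The", "weather", "is", "nice"],
    ["The", "weather", "was", "nice", "today"],
)
# stable == ["The", "weather"]; pending == ["was", "nice", "today"]
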
@@ -122,28 +124,28 @@ class OnlineTranslation:
     def insert_tokens(self, tokens):
-        self.buffer.extend(tokens)
+        self.input_buffer.extend(tokens)
         pass

     def process(self):
         i = 0
-        if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
+        if len(self.input_buffer) < self.len_processed_buffer + 3: #nothing new to process
             return self.validated + [self.translation_remaining]
-        while i < len(self.buffer):
-            if self.buffer[i].is_punctuation():
-                translation_sentence = self.translate_tokens(self.buffer[:i+1])
+        while i < len(self.input_buffer):
+            if self.input_buffer[i].is_punctuation():
+                translation_sentence = self.translate_tokens(self.input_buffer[:i+1])
                 self.validated.append(translation_sentence)
-                self.buffer = self.buffer[i+1:]
+                self.input_buffer = self.input_buffer[i+1:]
                 i = 0
             else:
                 i+=1
-        self.translation_remaining = self.translate_tokens(self.buffer)
-        self.len_processed_buffer = len(self.buffer)
-        return self.validated + [self.translation_remaining]
+        self.translation_remaining = self.translate_tokens(self.input_buffer)
+        self.len_processed_buffer = len(self.input_buffer)
+        return self.validated, [self.translation_remaining]

     def insert_silence(self, silence_duration: float):
         if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
-            self.buffer = []
+            self.input_buffer = []
             self.validated += [self.translation_remaining]

 if __name__ == '__main__':
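
With this hunk, the final return of process() becomes a (validated, remaining) pair instead of one concatenated list, while the early-exit path above still returns the concatenated form. A hypothetical caller sketch, assuming online_translation is an OnlineTranslation instance and new_tokens is a list of token objects exposing is_punctuation() (names are illustrative, not from this commit):

online_translation.insert_tokens(new_tokens)          # hypothetical token objects
validated, remaining = online_translation.process()   # now a pair, not one list
for sentence in validated:
    print("validated:", sentence)                     # closed by punctuation, stable
for draft in remaining:
    print("in progress:", draft)                      # may still change
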
@@ -157,13 +159,13 @@ if __name__ == '__main__':
     test = test_string.split(' ')
     step = len(test) // 3
-    shared_model = load_model([input_lang], nllb_backend='ctranslate2')
+    shared_model = load_model([input_lang], nllb_backend='transformers')
     online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
     beg_inference = time.time()
     for id in range(5):
         val = test[id*step : (id+1)*step]
         val_str = ' '.join(val)
-        result = online_translation.translate(val_str)
+        result = online_translation.translate(val_str, input_lang = input_lang, output_lang = output_lang)
         print(result)
     print('inference time:', time.time() - beg_inference)
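
The test driver now exercises the transformers backend. For that backend, the standard NLLB generation pattern forces the target language code as the first generated token; a minimal standalone sketch, not taken from this commit (model name and language codes are illustrative):

import transformers

model_name = "facebook/nllb-200-distilled-600M"
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, src_lang="eng_Latn")

inputs = tokenizer("The weather is nice today.", return_tensors="pt")
generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),  # target language
    max_length=64,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
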