mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
translation now separates validated from output buffer tokens
This commit is contained in:
@@ -35,17 +35,19 @@ class TranslationModel():
|
||||
|
||||
def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
MODEL = f'nllb-200-distilled-{nllb_size}-ctranslate2'
|
||||
|
||||
if nllb_backend=='ctranslate2':
|
||||
model = f'nllb-200-distilled-{nllb_size}-ctranslate2'
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + model,local_dir=model)
|
||||
translator = ctranslate2.Translator(model,device=device)
|
||||
elif nllb_backend=='transformers':
|
||||
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{nllb_size}")
|
||||
model = f"facebook/nllb-200-distilled-{nllb_size}"
|
||||
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(model)
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
if src_lang != 'auto':
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(model, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
|
||||
translation_model = TranslationModel(
|
||||
translator=translator,
|
||||
@@ -61,7 +63,7 @@ def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
self.input_buffer = []
|
||||
self.len_processed_buffer = 0
|
||||
self.translation_remaining = Translation()
|
||||
self.validated = []
|
||||
@@ -72,13 +74,13 @@ class OnlineTranslation:
|
||||
|
||||
def compute_common_prefix(self, results):
|
||||
#we dont want want to prune the result for the moment.
|
||||
if not self.buffer:
|
||||
self.buffer = results
|
||||
if not self.input_buffer:
|
||||
self.input_buffer = results
|
||||
else:
|
||||
for i in range(min(len(self.buffer), len(results))):
|
||||
if self.buffer[i] != results[i]:
|
||||
self.commited.extend(self.buffer[:i])
|
||||
self.buffer = results[i:]
|
||||
for i in range(min(len(self.input_buffer), len(results))):
|
||||
if self.input_buffer[i] != results[i]:
|
||||
self.commited.extend(self.input_buffer[:i])
|
||||
self.input_buffer = results[i:]
|
||||
|
||||
def translate(self, input, input_lang, output_lang):
|
||||
if not input:
|
||||
@@ -122,28 +124,28 @@ class OnlineTranslation:
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
|
||||
self.buffer.extend(tokens)
|
||||
self.input_buffer.extend(tokens)
|
||||
pass
|
||||
|
||||
def process(self):
|
||||
i = 0
|
||||
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
|
||||
if len(self.input_buffer) < self.len_processed_buffer + 3: #nothing new to process
|
||||
return self.validated + [self.translation_remaining]
|
||||
while i < len(self.buffer):
|
||||
if self.buffer[i].is_punctuation():
|
||||
translation_sentence = self.translate_tokens(self.buffer[:i+1])
|
||||
while i < len(self.input_buffer):
|
||||
if self.input_buffer[i].is_punctuation():
|
||||
translation_sentence = self.translate_tokens(self.input_buffer[:i+1])
|
||||
self.validated.append(translation_sentence)
|
||||
self.buffer = self.buffer[i+1:]
|
||||
self.input_buffer = self.input_buffer[i+1:]
|
||||
i = 0
|
||||
else:
|
||||
i+=1
|
||||
self.translation_remaining = self.translate_tokens(self.buffer)
|
||||
self.len_processed_buffer = len(self.buffer)
|
||||
return self.validated + [self.translation_remaining]
|
||||
self.translation_remaining = self.translate_tokens(self.input_buffer)
|
||||
self.len_processed_buffer = len(self.input_buffer)
|
||||
return self.validated, [self.translation_remaining]
|
||||
|
||||
def insert_silence(self, silence_duration: float):
|
||||
if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
|
||||
self.buffer = []
|
||||
self.input_buffer = []
|
||||
self.validated += [self.translation_remaining]
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -157,13 +159,13 @@ if __name__ == '__main__':
|
||||
test = test_string.split(' ')
|
||||
step = len(test) // 3
|
||||
|
||||
shared_model = load_model([input_lang], nllb_backend='ctranslate2')
|
||||
shared_model = load_model([input_lang], nllb_backend='transformers')
|
||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||
|
||||
beg_inference = time.time()
|
||||
for id in range(5):
|
||||
val = test[id*step : (id+1)*step]
|
||||
val_str = ' '.join(val)
|
||||
result = online_translation.translate(val_str)
|
||||
result = online_translation.translate(val_str, input_lang = input_lang, output_lang = output_lang)
|
||||
print(result)
|
||||
print('inference time:', time.time() - beg_inference)
|
||||
Reference in New Issue
Block a user