# Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
# (synced 2026-03-07 22:33:36 +00:00) — 137 lines, 5.1 KiB, Python.
import ctranslate2
|
||
import torch
|
||
import transformers
|
||
from dataclasses import dataclass
|
||
import huggingface_hub
|
||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||
from whisperlivekit.timed_objects import Translation
|
||
|
||
|
||
# In the diarization case, we may want to translate only one speaker, or at least start new sentences at speaker changes.
|
||
|
||
# Sentence-final punctuation used by OnlineTranslation.process to cut the token
# stream into sentences: ASCII marks plus the CJK full stop and the full-width
# exclamation / question marks (a set literal repeating the ASCII '!' and '?'
# would silently collapse and miss the full-width forms).
PUNCTUATION_MARKS = {'.', '!', '?', '。', '！', '？'}
|
||
|
||
|
||
@dataclass
class TranslationModel:
    """Loaded translation backend: one translator plus its tokenizers.

    Instances are produced by ``load_model``, which fills ``tokenizer`` with
    one tokenizer per requested source-language code, keyed by that code.
    """

    # CTranslate2 translator built from the downloaded NLLB model directory.
    translator: ctranslate2.Translator
    # Source-language code -> tokenizer created for that source language.
    tokenizer: dict
|
||
|
||
def load_model(src_langs, model_name='nllb-200-distilled-600M-ctranslate2', model_owner='entai2965', device=None):
    """Download (if needed) and load the NLLB translation model.

    Args:
        src_langs: Source-language codes that will be translated from. One
            tokenizer is created per code, and the returned ``tokenizer`` dict
            is keyed by these exact codes (OnlineTranslation looks tokenizers
            up by them).
        model_name: Hugging Face repository name of the CTranslate2 model;
            also used as the local download / load directory.
        model_owner: Hugging Face user or organisation owning ``model_name``.
        device: ``"cuda"`` or ``"cpu"``. Defaults to CUDA when available.

    Returns:
        TranslationModel holding the translator and the per-language tokenizers.
    """
    huggingface_hub.snapshot_download(f"{model_owner}/{model_name}", local_dir=model_name)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    translator = ctranslate2.Translator(model_name, device=device)

    tokenizer = {}
    for src_lang in src_langs:
        # The NLLB tokenizer needs an NLLB language code (e.g. "eng_Latn") as
        # src_lang; a plain code such as "en" is not a vocabulary token and
        # would be tagged as unknown. Convert it the same way the output
        # language is converted, falling back to the raw code when there is no
        # mapping (e.g. it already is an NLLB code).
        # NOTE(review): assumes get_nllb_code returns a falsy value rather than
        # raising for codes it does not know — confirm.
        nllb_src_lang = get_nllb_code(src_lang) or src_lang
        tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(
            model_name,
            src_lang=nllb_src_lang,
            clean_up_tokenization_spaces=True,
        )

    return TranslationModel(
        translator=translator,
        tokenizer=tokenizer,
    )
|
||
|
||
def translate(input, translation_model, tgt_lang, src_lang=None):
    """Translate ``input`` into ``tgt_lang`` with ``translation_model``.

    Args:
        input: Source text. Empty text returns "" without running the model.
        translation_model: Object exposing ``translator`` (with
            ``translate_batch``) and ``tokenizer``. ``tokenizer`` may be a
            single tokenizer or, as built by ``load_model``, a dict mapping
            source-language codes to tokenizers.
        tgt_lang: Target-language token used as the decoder prefix; passed
            through unchanged (callers supply the model's own code).
        src_lang: Key selecting the tokenizer when ``tokenizer`` is a dict.
            Defaults to the dict's first key.

    Returns:
        The decoded translated text.

    Raises:
        ValueError: ``tokenizer`` is an empty dict and no ``src_lang`` is given.
        KeyError: ``src_lang`` is not a key of the tokenizer dict.
    """
    if not input:
        return ""

    tokenizer = translation_model.tokenizer
    if isinstance(tokenizer, dict):
        # load_model stores one tokenizer per source language; calling
        # tokenizer methods on the dict itself raises AttributeError.
        if src_lang is None:
            if not tokenizer:
                raise ValueError("translation_model has no tokenizers loaded")
            src_lang = next(iter(tokenizer))
        tokenizer = tokenizer[src_lang]

    source = tokenizer.convert_ids_to_tokens(tokenizer.encode(input))
    results = translation_model.translator.translate_batch([source], target_prefix=[[tgt_lang]])
    # The best hypothesis starts with the forced target-language prefix; drop it.
    target = results[0].hypotheses[0][1:]
    return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||
|
||
class OnlineTranslation:
    """Incrementally translate a growing stream of timed tokens.

    Tokens are appended with ``insert_tokens``. ``process`` cuts the buffer into
    sentences at sentence-final punctuation, translates each completed sentence
    once (kept in ``validated``) and re-translates the unfinished tail on every
    run (kept in ``translation_remaining``).
    """

    # Minimum number of tokens that must arrive after the last ``process`` run
    # before translation is recomputed (a cheap debounce).
    MIN_NEW_TOKENS = 3

    def __init__(self, translation_model: "TranslationModel", input_languages: list, output_languages: list):
        # Tokens not yet part of a validated sentence.
        self.buffer = []
        # Length of ``buffer`` right after the last ``process`` run.
        self.len_processed_buffer = 0
        # Translation of the unfinished tail of ``buffer``.
        self.translation_remaining = Translation()
        # Translations of completed sentences, in order.
        self.validated = []
        # Not read by any method of this class.
        self.translation_pending_validation = ''
        # Items confirmed by ``compute_common_prefix``; it extends this list,
        # so it must exist before that method is first called.
        self.commited = []
        self.translation_model = translation_model
        # Index 0 of each list is the default language used by ``translate``.
        self.input_languages = input_languages
        self.output_languages = output_languages

    def compute_common_prefix(self, results):
        """Commit the prefix shared by the buffered hypothesis and ``results``.

        With an empty buffer, ``results`` simply becomes the buffer. Otherwise
        the items before the first position where the buffer and ``results``
        differ are appended to ``self.commited``, and the buffer is replaced by
        the remainder of ``results``.
        """
        # We don't want to prune the result for the moment.
        if not self.buffer:
            # Copy so later insert_tokens calls don't mutate the caller's list.
            self.buffer = list(results)
            return
        for i, (old, new) in enumerate(zip(self.buffer, results)):
            if old != new:
                self.commited.extend(self.buffer[:i])
                self.buffer = list(results[i:])
                # Stop at the first divergence; continuing would compare the
                # replaced (shifted) buffer and can index past its end.
                break
        # NOTE(review): when one sequence is a prefix of the other, the buffer is
        # left unchanged — confirm whether the longer hypothesis should be kept.

    def translate(self, input, input_lang=None, output_lang=None):
        """Translate ``input`` text from ``input_lang`` to ``output_lang``.

        Languages default to the first configured input/output language. The
        output language is converted with ``get_nllb_code`` before use. Returns
        "" for empty input without running the model.
        """
        if not input:
            return ""
        if input_lang is None:
            input_lang = self.input_languages[0]
        if output_lang is None:
            output_lang = self.output_languages[0]
        nllb_output_lang = get_nllb_code(output_lang)

        tokenizer = self.translation_model.tokenizer[input_lang]
        source = tokenizer.convert_ids_to_tokens(tokenizer.encode(input))
        # return_attention=True could be used to try to optimize this further.
        results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
        # The best hypothesis starts with the forced target-language prefix; drop it.
        target = results[0].hypotheses[0][1:]
        return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))

    def translate_tokens(self, tokens):
        """Translate a run of timed tokens into a single ``Translation``.

        Token texts are joined with single spaces; the result spans from the
        first token's ``start`` to the last token's ``end``. Returns None for an
        empty sequence.
        """
        if not tokens:
            return None
        text = ' '.join(token.text for token in tokens)
        return Translation(
            text=self.translate(text),
            start=tokens[0].start,
            end=tokens[-1].end,
        )

    def insert_tokens(self, tokens):
        """Append newly transcribed tokens to the pending buffer."""
        self.buffer.extend(tokens)

    def process(self):
        """Translate completed sentences and the unfinished tail of the buffer.

        Returns the validated sentence translations followed by the translation
        of the unfinished tail (an empty ``Translation()`` when there is no
        tail). If fewer than ``MIN_NEW_TOKENS`` tokens arrived since the last
        run, the previous results are returned unchanged.
        """
        if len(self.buffer) < self.len_processed_buffer + self.MIN_NEW_TOKENS:
            return self.validated + [self.translation_remaining]

        # Single pass: each sentence-ending token closes the span that started
        # right after the previous sentence end.
        start = 0
        for i, token in enumerate(self.buffer):
            if self._ends_sentence(token.text):
                self.validated.append(self.translate_tokens(self.buffer[start:i + 1]))
                start = i + 1
        self.buffer = self.buffer[start:]

        remaining = self.translate_tokens(self.buffer)
        # translate_tokens returns None for an empty tail; keep exposing a
        # Translation (as in the initial state) so callers never receive None.
        self.translation_remaining = remaining if remaining is not None else Translation()
        self.len_processed_buffer = len(self.buffer)
        return self.validated + [self.translation_remaining]

    @staticmethod
    def _ends_sentence(text):
        """Return True if ``text`` ends with a mark from ``PUNCTUATION_MARKS``.

        Matches bare punctuation tokens (".") as well as words carrying trailing
        punctuation ("now?"); surrounding whitespace is ignored.
        """
        stripped = text.strip()
        return bool(stripped) and stripped[-1] in PUNCTUATION_MARKS
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: translate an English paragraph into French in a few
    # word chunks, roughly as a streaming caller would feed it.
    output_lang = 'fr'
    input_lang = "en"

    test_string = """
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
"""
    # split() with no argument also drops the surrounding newlines, which
    # split(' ') would have glued onto the first and last words.
    words = test_string.split()
    n_chunks = 3
    # Ceiling division: every word lands in one of n_chunks chunks, no empty
    # trailing chunks are translated, and step is never 0 for short input.
    step = max(1, -(-len(words) // n_chunks))

    shared_model = load_model([input_lang])
    online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])

    for chunk_start in range(0, len(words), step):
        chunk = ' '.join(words[chunk_start:chunk_start + step])
        result = online_translation.translate(chunk)
        print(result)