diff --git a/README.md b/README.md
index 2c267a7..656b5fb 100644
--- a/README.md
+++ b/README.md
@@ -200,7 +200,8 @@ An important list of parameters can be changed. But what *should* you change?
 
 | Translation options | Description | Default |
 |-----------|-------------|---------|
-| `--nllb-backend` | [NOT FUNCTIONNAL YET] transformer or ctranslate2 | `ctranslate2` |
+| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
+| `--nllb-size` | `600M` or `1.3B` | `600M` |
 
 > For diarization using Diart, you need access to pyannote.audio models:
 > 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
index 7f2eaf4..578e624 100644
--- a/whisperlivekit/core.py
+++ b/whisperlivekit/core.py
@@ -70,7 +70,8 @@ class TranscriptionEngine:
             "embedding_model": "pyannote/embedding",
 
             # translation params:
-            "nllb_backend": "ctranslate2"
+            "nllb_backend": "ctranslate2",
+            "nllb_size": "600M"
         }
 
         config_dict = {**defaults, **kwargs}
@@ -148,8 +149,7 @@ class TranscriptionEngine:
                 raise Exception('Translation cannot be set with language auto')
             else:
                 from whisperlivekit.translation.translation import load_model
-                self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend) #in the future we want to handle different languages for different speakers
-
+                self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers
 
         TranscriptionEngine._initialized = True
 
diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py
index c73e6d4..55d4173 100644
--- a/whisperlivekit/parse_args.py
+++ b/whisperlivekit/parse_args.py
@@ -291,7 +291,14 @@ def parse_args():
         "--nllb-backend",
         type=str,
         default="ctranslate2",
-        help="transformer or ctranslate2",
+        help="transformers or ctranslate2",
+    )
+
+    simulstreaming_group.add_argument(
+        "--nllb-size",
+        type=str,
+        default="600M",
+        help="600M or 1.3B",
     )
 
     args = parser.parse_args()
diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py
index 7923243..c08f190 100644
--- a/whisperlivekit/translation/translation.py
+++ b/whisperlivekit/translation/translation.py
@@ -1,4 +1,5 @@
 import logging
+import time
 import ctranslate2
 import torch
 import transformers
@@ -20,39 +21,34 @@ MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider t
 class TranslationModel():
     translator: ctranslate2.Translator
     tokenizer: dict
+    device: str
+    backend_type: str = 'ctranslate2'
 
-def load_model(src_langs, backend='ctranslate2'):
+def load_model(src_langs, backend='ctranslate2', model_size='600M'):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    MODEL = f'nllb-200-distilled-{model_size}-ctranslate2'
     if backend=='ctranslate2':
-        MODEL = 'nllb-200-distilled-600M-ctranslate2'
         MODEL_GUY = 'entai2965'
         huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
-        device = "cuda" if torch.cuda.is_available() else "cpu"
         translator = ctranslate2.Translator(MODEL,device=device)
-        tokenizer = dict()
-        for src_lang in src_langs:
-            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
+        tokenizer_source = MODEL  # the downloaded snapshot already ships the tokenizer files
     elif backend=='transformers':
-        raise Exception('not implemented yet')
+        # no snapshot is downloaded on this path, so load both the model and the
+        # tokenizer from the original facebook checkpoint, and move the model to the device
+        tokenizer_source = f"facebook/nllb-200-distilled-{model_size}"
+        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(tokenizer_source).to(device)
+
+    tokenizer = dict()
+    for src_lang in src_langs:
+        tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(tokenizer_source, src_lang=src_lang, clean_up_tokenization_spaces=True)
+
     return TranslationModel(
         translator=translator,
-        tokenizer=tokenizer
+        tokenizer=tokenizer,
+        backend_type=backend,
+        device=device
     )
 
-def translate(input, translation_model, tgt_lang, src_lang="en"):
-    # Get the specific tokenizer for the source language
-    tokenizer = translation_model.tokenizer[src_lang]
-
-    # Convert input to tokens
-    source = tokenizer.convert_ids_to_tokens(tokenizer.encode(input))
-
-    # Translate with target language prefix
-    target_prefix = [tgt_lang]
-    results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
-
-    # Get translated tokens and decode
-    target = results[0].hypotheses[0][1:]
-    return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
-
 class OnlineTranslation:
     def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
         self.buffer = []
@@ -83,12 +79,19 @@ class OnlineTranslation:
         output_lang = self.output_languages[0]
         nllb_output_lang = get_nllb_code(output_lang)
 
-        source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
-        results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) #we can use return_attention=True to try to optimize the stuff.
-        target = results[0].hypotheses[0][1:]
-        results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
-        return results
-
+        tokenizer = self.translation_model.tokenizer[input_lang]
+        tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
+
+        if self.translation_model.backend_type == 'ctranslate2':
+            source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
+            results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
+            target = results[0].hypotheses[0][1:]
+            result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
+        else:
+            translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
+            result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
+        return result
+
     def translate_tokens(self, tokens):
         if tokens:
             text = ' '.join([token.text for token in tokens])
@@ -103,7 +106,6 @@ class OnlineTranslation:
             return translation
         return None
 
-
     def insert_tokens(self, tokens):
         self.buffer.extend(tokens)
 
@@ -141,16 +143,13 @@ if __name__ == '__main__':
     test = test_string.split(' ')
     step = len(test) // 3
 
-    shared_model = load_model([input_lang])
+    shared_model = load_model([input_lang], backend='ctranslate2')
     online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
-
+
+    beg_inference = time.time()
    for id in range(5):
         val = test[id*step : (id+1)*step]
         val_str = ' '.join(val)
         result = online_translation.translate(val_str)
         print(result)
-
-
-
-
-    # print(result)
+    print('inference time:', time.time() - beg_inference)
\ No newline at end of file
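
Usage sketch, not part of the diff: a minimal way to exercise both backends and the new `model_size` parameter end to end. This is an assumed example: `'en'` and `'fr'` stand in for the `input_lang`/`output_lang` values defined above the `__main__` hunk, and it presumes the dependencies of both backends are installed.

    import time
    from whisperlivekit.translation.translation import OnlineTranslation, load_model

    for backend in ('ctranslate2', 'transformers'):
        # load_model downloads the chosen NLLB checkpoint on first use
        model = load_model(['en'], backend=backend, model_size='600M')
        online = OnlineTranslation(model, input_languages=['en'], output_languages=['fr'])
        beg = time.time()
        print(backend, '->', online.translate('The weather is nice today.'))
        print('inference time:', time.time() - beg)

The same switch is exposed on the CLI through `--nllb-backend` (`ctranslate2` by default) and `--nllb-size` (`600M` by default).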