diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py
index ccfedb5..edea57b 100644
--- a/whisperlivekit/translation/translation.py
+++ b/whisperlivekit/translation/translation.py
@@ -10,13 +10,17 @@ class TranslationModel():
     tokenizer: transformers.AutoTokenizer
 
 def load_model(src_lang):
-    huggingface_hub.snapshot_download('entai2965/nllb-200-distilled-600M-ctranslate2',local_dir='nllb-200-distilled-600M-ctranslate2')
-    translator = ctranslate2.Translator("nllb-200-distilled-600M-ctranslate2",device="cpu")
-    tokenizer = transformers.AutoTokenizer.from_pretrained("nllb-200-distilled-600M-ctranslate2", src_lang=src_lang, clean_up_tokenization_spaces=True)
+    MODEL = 'nllb-200-distilled-600M-ctranslate2'
+    MODEL_OWNER = 'entai2965'
+    huggingface_hub.snapshot_download(MODEL_OWNER + '/' + MODEL, local_dir=MODEL)
+    # ctranslate2.Translator expects the device as a plain string ("cpu"/"cuda"),
+    # not a torch.device object. Use ctranslate2's own CUDA probe instead of
+    # importing torch solely for availability detection.
+    device = "cuda" if ctranslate2.get_cuda_device_count() > 0 else "cpu"
+    translator = ctranslate2.Translator(MODEL, device=device)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
     return TranslationModel(
         translator=translator,
         tokenizer=tokenizer
-    )
+    )
 
 def translate(input, translation_model, tgt_lang):
     if not input: