From b6164aa59b5cc00a78d164d5ecb98d9e989d5b06 Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Mon, 8 Sep 2025 11:34:40 +0200
Subject: [PATCH] translation device determined with torch.device

---
 whisperlivekit/translation/translation.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py
index ccfedb5..edea57b 100644
--- a/whisperlivekit/translation/translation.py
+++ b/whisperlivekit/translation/translation.py
@@ -1,4 +1,5 @@
 import ctranslate2
+import torch
 import transformers
 from dataclasses import dataclass
 import huggingface_hub
@@ -10,13 +11,16 @@ class TranslationModel():
     tokenizer: transformers.AutoTokenizer
 
 def load_model(src_lang):
-    huggingface_hub.snapshot_download('entai2965/nllb-200-distilled-600M-ctranslate2',local_dir='nllb-200-distilled-600M-ctranslate2')
-    translator = ctranslate2.Translator("nllb-200-distilled-600M-ctranslate2",device="cpu")
-    tokenizer = transformers.AutoTokenizer.from_pretrained("nllb-200-distilled-600M-ctranslate2", src_lang=src_lang, clean_up_tokenization_spaces=True)
+    MODEL = 'nllb-200-distilled-600M-ctranslate2'
+    MODEL_OWNER = 'entai2965'
+    huggingface_hub.snapshot_download(MODEL_OWNER + '/' + MODEL, local_dir=MODEL)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # CTranslate2 expects the device as a string ("cpu"/"cuda"), not a torch.device
+    translator = ctranslate2.Translator(MODEL, device=device.type)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
     return TranslationModel(
         translator=translator,
         tokenizer=tokenizer
-        )
+    )
 
 def translate(input, translation_model, tgt_lang):
     if not input: