translation device determined with torch.device

This commit is contained in:
Quentin Fuxa
2025-09-08 11:34:40 +02:00
parent 4209d7f7c0
commit b6164aa59b

View File

@@ -1,4 +1,5 @@
import ctranslate2
import torch
import transformers
from dataclasses import dataclass
import huggingface_hub
@@ -10,13 +11,16 @@ class TranslationModel():
tokenizer: transformers.AutoTokenizer
def load_model(src_lang):
huggingface_hub.snapshot_download('entai2965/nllb-200-distilled-600M-ctranslate2',local_dir='nllb-200-distilled-600M-ctranslate2')
translator = ctranslate2.Translator("nllb-200-distilled-600M-ctranslate2",device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained("nllb-200-distilled-600M-ctranslate2", src_lang=src_lang, clean_up_tokenization_spaces=True)
MODEL = 'nllb-200-distilled-600M-ctranslate2'
MODEL_GUY = 'entai2965'
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
translator = ctranslate2.Translator(MODEL,device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
return TranslationModel(
translator=translator,
tokenizer=tokenizer
)
)
def translate(input, translation_model, tgt_lang):
if not input: