mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
translation device determined with torch.device
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
from dataclasses import dataclass
|
||||
import huggingface_hub
|
||||
@@ -10,13 +11,16 @@ class TranslationModel():
|
||||
tokenizer: transformers.AutoTokenizer
|
||||
|
||||
def load_model(src_lang):
    """Download the NLLB-200 CTranslate2 model and wrap it in a TranslationModel.

    Downloads (or reuses a cached copy of) the ``entai2965`` CTranslate2
    conversion of NLLB-200-distilled-600M into a local directory, then
    builds the translator and its matching tokenizer.

    Parameters
    ----------
    src_lang : str
        NLLB-style source-language code (e.g. ``"eng_Latn"``) forwarded to
        the tokenizer so it prefixes inputs with the right language token.

    Returns
    -------
    TranslationModel
        Dataclass bundling the ``ctranslate2.Translator`` and the
        ``transformers`` tokenizer.
    """
    MODEL = 'nllb-200-distilled-600M-ctranslate2'
    MODEL_OWNER = 'entai2965'
    huggingface_hub.snapshot_download(f'{MODEL_OWNER}/{MODEL}', local_dir=MODEL)
    # NOTE: ctranslate2.Translator expects a device *string* ("cpu"/"cuda"),
    # not a torch.device object — passing torch.device raises a TypeError at
    # the binding layer, so keep the selection as a plain str.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    translator = ctranslate2.Translator(MODEL, device=device)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
    return TranslationModel(
        translator=translator,
        tokenizer=tokenizer
    )
|
||||
|
||||
def translate(input, translation_model, tgt_lang):
|
||||
if not input:
|
||||
|
||||
Reference in New Issue
Block a user