translation: use of get_nllb_code

This commit is contained in:
Quentin Fuxa
2025-09-07 15:25:14 +02:00
parent 84890b8e61
commit 72f33be6f2
3 changed files with 11 additions and 8 deletions

View File

@@ -133,12 +133,14 @@ class TranscriptionEngine:
self.diarization_model = SortformerDiarization()
else:
raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}")
self.translation_model = None
if self.args.target_language:
if self.args.language == 'auto':
raise Exception('Translation cannot be set with language auto')
else:
from whisperlivekit.translation.translation import load_model
self.translation_model = load_model()
TranscriptionEngine._initialized = True

View File

@@ -132,7 +132,7 @@ NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
def get_nllb_code(crowdin_code):
return CROWDIN_TO_NLLB.get(crowdin_code, crowdin_code)
return CROWDIN_TO_NLLB.get(crowdin_code, None)
def get_crowdin_code(nllb_code):

View File

@@ -2,8 +2,7 @@ import ctranslate2
import transformers
from dataclasses import dataclass
import huggingface_hub
src_lang = "eng_Latn"
from .mapping_languages import get_nllb_code
@dataclass
class TranslationModel():
@@ -30,8 +29,10 @@ def translate(input, translation_model, tgt_lang):
if __name__ == '__main__':
tgt_lang = "fra_Latn"
src_lang = "eng_Latn"
translation_model = load_model(src_lang)
result = translate('Hello world', translation_model=translation_model, tgt_lang=tgt_lang)
tgt_lang = 'fr'
src_lang = "en"
nllb_tgt_lang = get_nllb_code(tgt_lang)
nllb_src_lang = get_nllb_code(src_lang)
translation_model = load_model(nllb_src_lang)
result = translate('Hello world', translation_model=translation_model, tgt_lang=nllb_tgt_lang)
print(result)