Use optional new separate NLLW package for translation

This commit is contained in:
Quentin Fuxa
2025-10-30 19:36:28 +01:00
parent 939a7ebf8b
commit ece02db6a3
10 changed files with 12 additions and 607 deletions

View File

@@ -18,8 +18,8 @@ Real-time transcription directly to your browser, with a ready-to-use backend+se
#### Powered by Leading Research:
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufalSimul/Streaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simultaneous translation from & to 200 languages.
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
@@ -68,9 +68,9 @@ Go to `chrome-extension` for instructions.
| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| **Apple Silicon optimized backend** | `mlx-whisper` |
| **NLLB Translation** | `huggingface_hub` & `transformers` |
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| **Apple Silicon optimizations** | `mlx-whisper` |
| **Translation** | `nllw` |
| *[Not recommended]* Speaker diarization with Diart | `diart` |
| *[Not recommended]* Original Whisper backend | `whisper` |
| *[Not recommended]* Improved timestamps backend | `whisper-timestamped` |

View File

@@ -26,7 +26,7 @@ whisperlivekit-server --target-language fra_Latn
### Python API
```python
from whisperlivekit.translation import get_language_info
from nllw.translation import get_language_info
# Get language information by name
lang_info = get_language_info("French")

View File

@@ -41,7 +41,7 @@ dependencies = [
]
[project.optional-dependencies]
translation = ["transformers", "huggingface_hub"]
translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
[project.urls]
@@ -60,7 +60,6 @@ packages = [
"whisperlivekit.simul_whisper.whisper.normalizers",
"whisperlivekit.web",
"whisperlivekit.whisper_streaming_custom",
"whisperlivekit.translation",
"whisperlivekit.vad_models"
]

View File

@@ -141,9 +141,9 @@ class TranscriptionEngine:
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
else:
from whisperlivekit.translation.translation import load_model
from nllw import load_model
translation_params = {
"nllb_backend": "ctranslate2",
"nllb_backend": "transformers",
"nllb_size": "600M"
}
translation_params = update_with_kwargs(translation_params, kwargs)
@@ -175,5 +175,5 @@ def online_translation_factory(args, translation_model):
#should be at speaker level in the future:
#one shared nllb model for all speaker
#one tokenizer per speaker/language
from whisperlivekit.translation.translation import OnlineTranslation
from nllw import OnlineTranslation
return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -300,7 +300,7 @@ def parse_args():
simulstreaming_group.add_argument(
"--nllb-backend",
type=str,
default="ctranslate2",
default="transformers",
help="transformers or ctranslate2",
)

View File

@@ -222,7 +222,7 @@ class SimulStreamingASR():
self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
print('Simulstreaming will use MLX whisper to increase encoding speed.')
if self.model_path and compatible_whisper_mlx:
mlx_model = self.model_path
else:

View File

@@ -1,159 +0,0 @@
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import Tuple, Optional
# Module-level setup: load the distilled NLLB-200 model once at import time.
# NOTE(review): this downloads/loads the model as an import side effect.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
device = model.device
# Forced decoder BOS: hard-coded French as the target language for this experiment.
bos_token_id = tokenizer.convert_tokens_to_ids("fra_Latn")
def compute_common_prefix_tokens(
    prev_tokens: torch.Tensor,
    new_tokens: torch.Tensor,
    tokenizer: AutoTokenizer,
    sep: str = " "
) -> torch.Tensor:
    """Return the leading slice of *new_tokens* whose decoded words agree with *prev_tokens*.

    Agreement is computed at the word level (after decoding both token
    sequences), then mapped back to a token boundary by re-decoding
    incrementally longer prefixes of *new_tokens*.
    """
    # No previous hypothesis: everything new is trivially "common".
    if prev_tokens is None or len(prev_tokens) == 0:
        return new_tokens
    old_text = tokenizer.decode(prev_tokens, skip_special_tokens=True)
    fresh_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    if not old_text or not fresh_text:
        return new_tokens
    old_words = old_text.split(sep)
    fresh_words = fresh_text.split(sep)
    # Count how many leading words the two decodings share.
    shared = 0
    for old_word, fresh_word in zip(old_words, fresh_words):
        if old_word != fresh_word:
            break
        shared += 1
    if shared == 0:
        # Nothing in common: keep at most the first new token.
        return new_tokens[:1] if len(new_tokens) > 0 else new_tokens
    if shared == len(old_words) and len(old_words) == len(fresh_words):
        # Full word-level agreement: the whole new sequence is the prefix.
        return new_tokens
    # Find the shortest token prefix that decodes exactly to the shared words.
    target_text = sep.join(fresh_words[:shared])
    for cut in range(1, len(new_tokens) + 1):
        decoded = tokenizer.decode(new_tokens[:cut], skip_special_tokens=True)
        if decoded == target_text:
            return new_tokens[:cut]
        if len(decoded) > len(target_text):
            # Overshot the boundary: back off one token (never below one).
            return new_tokens[:max(1, cut - 1)]
    return new_tokens
def manual_generate(
    encoder_outputs: torch.Tensor,
    attention_mask: torch.Tensor,
    forced_bos_token_id: int,
    max_length: int = 50,
    eos_token_id: Optional[int] = None
) -> Tuple[torch.Tensor, Optional[Tuple]]:
    """Run one full greedy generation pass from precomputed encoder outputs.

    Args:
        encoder_outputs: output of ``model.get_encoder()`` for the source text.
        attention_mask: source attention mask from the tokenizer.
        forced_bos_token_id: target-language token forced as the first decoded token.
        max_length: generation budget forwarded to ``model.generate``.
        eos_token_id: stop token; defaults to the tokenizer's EOS.

    Returns:
        (generated token ids, None) — the second element is a placeholder kept
        for interface compatibility with cache-returning variants.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    with torch.no_grad():
        generated_tokens = model.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            forced_bos_token_id=forced_bos_token_id,
            # Bug fix: both parameters were accepted but silently ignored,
            # so callers could not control the budget or the stop token.
            max_length=max_length,
            eos_token_id=eos_token_id,
        )
    return generated_tokens, None
def continue_generation_with_cache(
    encoder_hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    prefix_tokens: torch.Tensor,
    max_new_tokens: int = 50,
    eos_token_id: Optional[int] = None
) -> torch.Tensor:
    """Greedily extend an existing decoded prefix, reusing the decoder KV cache.

    The prefix is run through the decoder once to prime ``past_key_values``;
    tokens are then generated one at a time (argmax) until EOS or until the
    total length reaches ``max_new_tokens``.

    NOTE(review): ``attention_mask`` is accepted but never forwarded to the
    decoder (no ``encoder_attention_mask``) — presumably fine for unpadded
    single-sequence input; confirm before batching.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    prefix_tokens = prefix_tokens.to(device)
    if prefix_tokens.dim() == 1:
        # Decoder expects a batch dimension: (1, prefix_len).
        prefix_tokens = prefix_tokens.unsqueeze(0)
    with torch.no_grad():
        # Single pass over the whole prefix to build the KV cache.
        decoder_out = model.model.decoder(
            input_ids=prefix_tokens,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=True,
            return_dict=True,
        )
        prefix_logits = model.lm_head(decoder_out.last_hidden_state)
        past_key_values = decoder_out.past_key_values
        # Greedy pick of the first continuation token after the prefix.
        next_token_id = torch.argmax(prefix_logits[:, -1, :], dim=-1).unsqueeze(-1)
        if next_token_id.item() == eos_token_id:
            # Prefix is already a complete sequence — nothing to extend.
            return prefix_tokens
        generated_tokens = torch.cat([prefix_tokens, next_token_id], dim=-1)
        # Budget counts total length, not only newly generated tokens.
        tokens_to_generate = max_new_tokens - generated_tokens.shape[1]
        for _ in range(tokens_to_generate):
            # Feed only the newest token; the cache carries the earlier state.
            decoder_out = model.model.decoder(
                input_ids=next_token_id,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            logits = model.lm_head(decoder_out.last_hidden_state)
            next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
            past_key_values = decoder_out.past_key_values
            if next_token_id.item() == eos_token_id:
                break
            generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
    return generated_tokens
if __name__ == "__main__":
    # Demo: translate a growing English sentence to French in two steps,
    # reusing the first translation's tokens as the decoder prefix for the
    # second step (incremental-translation experiment).
    src_text_0 = "Have you noticed how accurate"
    src_text_1 = " LLM are now that GPU have became more powerful"
    inputs_0 = tokenizer(src_text_0, return_tensors="pt").to(device)
    encoder_outputs_0 = model.get_encoder()(**inputs_0)
    translation_tokens_0, _ = manual_generate(
        encoder_outputs=encoder_outputs_0,
        attention_mask=inputs_0['attention_mask'],
        forced_bos_token_id=bos_token_id,
    )
    print(f"translation 0: {tokenizer.decode(translation_tokens_0[0], skip_special_tokens=True)}")
    # Second step: encode the full sentence, then continue decoding from the
    # previous translation instead of starting from scratch.
    src_text_full = src_text_0 + src_text_1
    inputs_1 = tokenizer(src_text_full, return_tensors="pt").to(device)
    encoder_outputs_1 = model.get_encoder()(**inputs_1)
    translation_tokens_1 = continue_generation_with_cache(
        encoder_hidden_states=encoder_outputs_1.last_hidden_state,
        attention_mask=inputs_1['attention_mask'],
        prefix_tokens=translation_tokens_0[0].clone(),
        max_new_tokens=200
    )
    print(f"1: {tokenizer.decode(translation_tokens_1[0], skip_special_tokens=True)}")

View File

@@ -1,264 +0,0 @@
LANGUAGES = [
{"name": "Acehnese (Arabic script)", "nllb": "ace_Arab", "language_code": "ace_Arab"},
{"name": "Acehnese (Latin script)", "nllb": "ace_Latn", "language_code": "ace_Latn"},
{"name": "Mesopotamian Arabic", "nllb": "acm_Arab", "language_code": "acm_Arab"},
{"name": "Ta'izzi-Adeni Arabic", "nllb": "acq_Arab", "language_code": "acq_Arab"},
{"name": "Tunisian Arabic", "nllb": "aeb_Arab", "language_code": "aeb_Arab"},
{"name": "Afrikaans", "nllb": "afr_Latn", "language_code": "af"},
{"name": "South Levantine Arabic", "nllb": "ajp_Arab", "language_code": "ajp_Arab"},
{"name": "Akan", "nllb": "aka_Latn", "language_code": "ak"},
{"name": "Tosk Albanian", "nllb": "als_Latn", "language_code": "als"},
{"name": "Amharic", "nllb": "amh_Ethi", "language_code": "am"},
{"name": "North Levantine Arabic", "nllb": "apc_Arab", "language_code": "apc_Arab"},
{"name": "Modern Standard Arabic", "nllb": "arb_Arab", "language_code": "ar"},
{"name": "Modern Standard Arabic (Romanized)", "nllb": "arb_Latn", "language_code": "arb_Latn"},
{"name": "Najdi Arabic", "nllb": "ars_Arab", "language_code": "ars_Arab"},
{"name": "Moroccan Arabic", "nllb": "ary_Arab", "language_code": "ary_Arab"},
{"name": "Egyptian Arabic", "nllb": "arz_Arab", "language_code": "arz_Arab"},
{"name": "Assamese", "nllb": "asm_Beng", "language_code": "as"},
{"name": "Asturian", "nllb": "ast_Latn", "language_code": "ast"},
{"name": "Awadhi", "nllb": "awa_Deva", "language_code": "awa"},
{"name": "Central Aymara", "nllb": "ayr_Latn", "language_code": "ay"},
{"name": "South Azerbaijani", "nllb": "azb_Arab", "language_code": "azb"},
{"name": "North Azerbaijani", "nllb": "azj_Latn", "language_code": "az"},
{"name": "Bashkir", "nllb": "bak_Cyrl", "language_code": "ba"},
{"name": "Bambara", "nllb": "bam_Latn", "language_code": "bm"},
{"name": "Balinese", "nllb": "ban_Latn", "language_code": "ban"},
{"name": "Belarusian", "nllb": "bel_Cyrl", "language_code": "be"},
{"name": "Bemba", "nllb": "bem_Latn", "language_code": "bem"},
{"name": "Bengali", "nllb": "ben_Beng", "language_code": "bn"},
{"name": "Bhojpuri", "nllb": "bho_Deva", "language_code": "bho"},
{"name": "Banjar (Arabic script)", "nllb": "bjn_Arab", "language_code": "bjn_Arab"},
{"name": "Banjar (Latin script)", "nllb": "bjn_Latn", "language_code": "bjn_Latn"},
{"name": "Standard Tibetan", "nllb": "bod_Tibt", "language_code": "bo"},
{"name": "Bosnian", "nllb": "bos_Latn", "language_code": "bs"},
{"name": "Buginese", "nllb": "bug_Latn", "language_code": "bug"},
{"name": "Bulgarian", "nllb": "bul_Cyrl", "language_code": "bg"},
{"name": "Catalan", "nllb": "cat_Latn", "language_code": "ca"},
{"name": "Cebuano", "nllb": "ceb_Latn", "language_code": "ceb"},
{"name": "Czech", "nllb": "ces_Latn", "language_code": "cs"},
{"name": "Chokwe", "nllb": "cjk_Latn", "language_code": "cjk"},
{"name": "Central Kurdish", "nllb": "ckb_Arab", "language_code": "ckb"},
{"name": "Crimean Tatar", "nllb": "crh_Latn", "language_code": "crh"},
{"name": "Welsh", "nllb": "cym_Latn", "language_code": "cy"},
{"name": "Danish", "nllb": "dan_Latn", "language_code": "da"},
{"name": "German", "nllb": "deu_Latn", "language_code": "de"},
{"name": "Southwestern Dinka", "nllb": "dik_Latn", "language_code": "dik"},
{"name": "Dyula", "nllb": "dyu_Latn", "language_code": "dyu"},
{"name": "Dzongkha", "nllb": "dzo_Tibt", "language_code": "dz"},
{"name": "Greek", "nllb": "ell_Grek", "language_code": "el"},
{"name": "English", "nllb": "eng_Latn", "language_code": "en"},
{"name": "Esperanto", "nllb": "epo_Latn", "language_code": "eo"},
{"name": "Estonian", "nllb": "est_Latn", "language_code": "et"},
{"name": "Basque", "nllb": "eus_Latn", "language_code": "eu"},
{"name": "Ewe", "nllb": "ewe_Latn", "language_code": "ee"},
{"name": "Faroese", "nllb": "fao_Latn", "language_code": "fo"},
{"name": "Fijian", "nllb": "fij_Latn", "language_code": "fj"},
{"name": "Finnish", "nllb": "fin_Latn", "language_code": "fi"},
{"name": "Fon", "nllb": "fon_Latn", "language_code": "fon"},
{"name": "French", "nllb": "fra_Latn", "language_code": "fr"},
{"name": "Friulian", "nllb": "fur_Latn", "language_code": "fur-IT"},
{"name": "Nigerian Fulfulde", "nllb": "fuv_Latn", "language_code": "fuv"},
{"name": "West Central Oromo", "nllb": "gaz_Latn", "language_code": "om"},
{"name": "Scottish Gaelic", "nllb": "gla_Latn", "language_code": "gd"},
{"name": "Irish", "nllb": "gle_Latn", "language_code": "ga-IE"},
{"name": "Galician", "nllb": "glg_Latn", "language_code": "gl"},
{"name": "Guarani", "nllb": "grn_Latn", "language_code": "gn"},
{"name": "Gujarati", "nllb": "guj_Gujr", "language_code": "gu-IN"},
{"name": "Haitian Creole", "nllb": "hat_Latn", "language_code": "ht"},
{"name": "Hausa", "nllb": "hau_Latn", "language_code": "ha"},
{"name": "Hebrew", "nllb": "heb_Hebr", "language_code": "he"},
{"name": "Hindi", "nllb": "hin_Deva", "language_code": "hi"},
{"name": "Chhattisgarhi", "nllb": "hne_Deva", "language_code": "hne"},
{"name": "Croatian", "nllb": "hrv_Latn", "language_code": "hr"},
{"name": "Hungarian", "nllb": "hun_Latn", "language_code": "hu"},
{"name": "Armenian", "nllb": "hye_Armn", "language_code": "hy-AM"},
{"name": "Igbo", "nllb": "ibo_Latn", "language_code": "ig"},
{"name": "Ilocano", "nllb": "ilo_Latn", "language_code": "ilo"},
{"name": "Indonesian", "nllb": "ind_Latn", "language_code": "id"},
{"name": "Icelandic", "nllb": "isl_Latn", "language_code": "is"},
{"name": "Italian", "nllb": "ita_Latn", "language_code": "it"},
{"name": "Javanese", "nllb": "jav_Latn", "language_code": "jv"},
{"name": "Japanese", "nllb": "jpn_Jpan", "language_code": "ja"},
{"name": "Kabyle", "nllb": "kab_Latn", "language_code": "kab"},
{"name": "Jingpho", "nllb": "kac_Latn", "language_code": "kac"},
{"name": "Kamba", "nllb": "kam_Latn", "language_code": "kam"},
{"name": "Kannada", "nllb": "kan_Knda", "language_code": "kn"},
{"name": "Kashmiri (Arabic script)", "nllb": "kas_Arab", "language_code": "kas_Arab"},
{"name": "Kashmiri (Devanagari script)", "nllb": "kas_Deva", "language_code": "kas_Deva"},
{"name": "Georgian", "nllb": "kat_Geor", "language_code": "ka"},
{"name": "Kazakh", "nllb": "kaz_Cyrl", "language_code": "kk"},
{"name": "Kabiyè", "nllb": "kbp_Latn", "language_code": "kbp"},
{"name": "Kabuverdianu", "nllb": "kea_Latn", "language_code": "kea"},
{"name": "Halh Mongolian", "nllb": "khk_Cyrl", "language_code": "mn"},
{"name": "Khmer", "nllb": "khm_Khmr", "language_code": "km"},
{"name": "Kikuyu", "nllb": "kik_Latn", "language_code": "ki"},
{"name": "Kinyarwanda", "nllb": "kin_Latn", "language_code": "rw"},
{"name": "Kyrgyz", "nllb": "kir_Cyrl", "language_code": "ky"},
{"name": "Kimbundu", "nllb": "kmb_Latn", "language_code": "kmb"},
{"name": "Northern Kurdish", "nllb": "kmr_Latn", "language_code": "kmr"},
{"name": "Central Kanuri (Arabic script)", "nllb": "knc_Arab", "language_code": "knc_Arab"},
{"name": "Central Kanuri (Latin script)", "nllb": "knc_Latn", "language_code": "knc_Latn"},
{"name": "Kikongo", "nllb": "kon_Latn", "language_code": "kg"},
{"name": "Korean", "nllb": "kor_Hang", "language_code": "ko"},
{"name": "Lao", "nllb": "lao_Laoo", "language_code": "lo"},
{"name": "Ligurian", "nllb": "lij_Latn", "language_code": "lij"},
{"name": "Limburgish", "nllb": "lim_Latn", "language_code": "li"},
{"name": "Lingala", "nllb": "lin_Latn", "language_code": "ln"},
{"name": "Lithuanian", "nllb": "lit_Latn", "language_code": "lt"},
{"name": "Lombard", "nllb": "lmo_Latn", "language_code": "lmo"},
{"name": "Latgalian", "nllb": "ltg_Latn", "language_code": "ltg"},
{"name": "Luxembourgish", "nllb": "ltz_Latn", "language_code": "lb"},
{"name": "Luba-Kasai", "nllb": "lua_Latn", "language_code": "lua"},
{"name": "Ganda", "nllb": "lug_Latn", "language_code": "lg"},
{"name": "Luo", "nllb": "luo_Latn", "language_code": "luo"},
{"name": "Mizo", "nllb": "lus_Latn", "language_code": "lus"},
{"name": "Standard Latvian", "nllb": "lvs_Latn", "language_code": "lv"},
{"name": "Magahi", "nllb": "mag_Deva", "language_code": "mag"},
{"name": "Maithili", "nllb": "mai_Deva", "language_code": "mai"},
{"name": "Malayalam", "nllb": "mal_Mlym", "language_code": "ml-IN"},
{"name": "Marathi", "nllb": "mar_Deva", "language_code": "mr"},
{"name": "Minangkabau (Arabic script)", "nllb": "min_Arab", "language_code": "min_Arab"},
{"name": "Minangkabau (Latin script)", "nllb": "min_Latn", "language_code": "min_Latn"},
{"name": "Macedonian", "nllb": "mkd_Cyrl", "language_code": "mk"},
{"name": "Maltese", "nllb": "mlt_Latn", "language_code": "mt"},
{"name": "Meitei (Bengali script)", "nllb": "mni_Beng", "language_code": "mni"},
{"name": "Mossi", "nllb": "mos_Latn", "language_code": "mos"},
{"name": "Maori", "nllb": "mri_Latn", "language_code": "mi"},
{"name": "Burmese", "nllb": "mya_Mymr", "language_code": "my"},
{"name": "Dutch", "nllb": "nld_Latn", "language_code": "nl"},
{"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "language_code": "nn-NO"},
{"name": "Norwegian Bokmål", "nllb": "nob_Latn", "language_code": "nb"},
{"name": "Nepali", "nllb": "npi_Deva", "language_code": "ne-NP"},
{"name": "Northern Sotho", "nllb": "nso_Latn", "language_code": "nso"},
{"name": "Nuer", "nllb": "nus_Latn", "language_code": "nus"},
{"name": "Nyanja", "nllb": "nya_Latn", "language_code": "ny"},
{"name": "Occitan", "nllb": "oci_Latn", "language_code": "oc"},
{"name": "Odia", "nllb": "ory_Orya", "language_code": "or"},
{"name": "Pangasinan", "nllb": "pag_Latn", "language_code": "pag"},
{"name": "Eastern Panjabi", "nllb": "pan_Guru", "language_code": "pa"},
{"name": "Papiamento", "nllb": "pap_Latn", "language_code": "pap"},
{"name": "Southern Pashto", "nllb": "pbt_Arab", "language_code": "pbt"},
{"name": "Western Persian", "nllb": "pes_Arab", "language_code": "fa"},
{"name": "Plateau Malagasy", "nllb": "plt_Latn", "language_code": "mg"},
{"name": "Polish", "nllb": "pol_Latn", "language_code": "pl"},
{"name": "Portuguese", "nllb": "por_Latn", "language_code": "pt-PT"},
{"name": "Dari", "nllb": "prs_Arab", "language_code": "fa-AF"},
{"name": "Ayacucho Quechua", "nllb": "quy_Latn", "language_code": "qu"},
{"name": "Romanian", "nllb": "ron_Latn", "language_code": "ro"},
{"name": "Rundi", "nllb": "run_Latn", "language_code": "rn"},
{"name": "Russian", "nllb": "rus_Cyrl", "language_code": "ru"},
{"name": "Sango", "nllb": "sag_Latn", "language_code": "sg"},
{"name": "Sanskrit", "nllb": "san_Deva", "language_code": "sa"},
{"name": "Santali", "nllb": "sat_Olck", "language_code": "sat"},
{"name": "Sicilian", "nllb": "scn_Latn", "language_code": "scn"},
{"name": "Shan", "nllb": "shn_Mymr", "language_code": "shn"},
{"name": "Sinhala", "nllb": "sin_Sinh", "language_code": "si-LK"},
{"name": "Slovak", "nllb": "slk_Latn", "language_code": "sk"},
{"name": "Slovenian", "nllb": "slv_Latn", "language_code": "sl"},
{"name": "Samoan", "nllb": "smo_Latn", "language_code": "sm"},
{"name": "Shona", "nllb": "sna_Latn", "language_code": "sn"},
{"name": "Sindhi", "nllb": "snd_Arab", "language_code": "sd"},
{"name": "Somali", "nllb": "som_Latn", "language_code": "so"},
{"name": "Southern Sotho", "nllb": "sot_Latn", "language_code": "st"},
{"name": "Spanish", "nllb": "spa_Latn", "language_code": "es-ES"},
{"name": "Sardinian", "nllb": "srd_Latn", "language_code": "sc"},
{"name": "Serbian", "nllb": "srp_Cyrl", "language_code": "sr"},
{"name": "Swati", "nllb": "ssw_Latn", "language_code": "ss"},
{"name": "Sundanese", "nllb": "sun_Latn", "language_code": "su"},
{"name": "Swedish", "nllb": "swe_Latn", "language_code": "sv-SE"},
{"name": "Swahili", "nllb": "swh_Latn", "language_code": "sw"},
{"name": "Silesian", "nllb": "szl_Latn", "language_code": "szl"},
{"name": "Tamil", "nllb": "tam_Taml", "language_code": "ta"},
{"name": "Tamasheq (Latin script)", "nllb": "taq_Latn", "language_code": "taq_Latn"},
{"name": "Tamasheq (Tifinagh script)", "nllb": "taq_Tfng", "language_code": "taq_Tfng"},
{"name": "Tatar", "nllb": "tat_Cyrl", "language_code": "tt-RU"},
{"name": "Telugu", "nllb": "tel_Telu", "language_code": "te"},
{"name": "Tajik", "nllb": "tgk_Cyrl", "language_code": "tg"},
{"name": "Tagalog", "nllb": "tgl_Latn", "language_code": "tl"},
{"name": "Thai", "nllb": "tha_Thai", "language_code": "th"},
{"name": "Tigrinya", "nllb": "tir_Ethi", "language_code": "ti"},
{"name": "Tok Pisin", "nllb": "tpi_Latn", "language_code": "tpi"},
{"name": "Tswana", "nllb": "tsn_Latn", "language_code": "tn"},
{"name": "Tsonga", "nllb": "tso_Latn", "language_code": "ts"},
{"name": "Turkmen", "nllb": "tuk_Latn", "language_code": "tk"},
{"name": "Tumbuka", "nllb": "tum_Latn", "language_code": "tum"},
{"name": "Turkish", "nllb": "tur_Latn", "language_code": "tr"},
{"name": "Twi", "nllb": "twi_Latn", "language_code": "tw"},
{"name": "Central Atlas Tamazight", "nllb": "tzm_Tfng", "language_code": "tzm"},
{"name": "Uyghur", "nllb": "uig_Arab", "language_code": "ug"},
{"name": "Ukrainian", "nllb": "ukr_Cyrl", "language_code": "uk"},
{"name": "Umbundu", "nllb": "umb_Latn", "language_code": "umb"},
{"name": "Urdu", "nllb": "urd_Arab", "language_code": "ur"},
{"name": "Northern Uzbek", "nllb": "uzn_Latn", "language_code": "uz"},
{"name": "Venetian", "nllb": "vec_Latn", "language_code": "vec"},
{"name": "Vietnamese", "nllb": "vie_Latn", "language_code": "vi"},
{"name": "Waray", "nllb": "war_Latn", "language_code": "war"},
{"name": "Wolof", "nllb": "wol_Latn", "language_code": "wo"},
{"name": "Xhosa", "nllb": "xho_Latn", "language_code": "xh"},
{"name": "Eastern Yiddish", "nllb": "ydd_Hebr", "language_code": "yi"},
{"name": "Yoruba", "nllb": "yor_Latn", "language_code": "yo"},
{"name": "Yue Chinese", "nllb": "yue_Hant", "language_code": "yue"},
{"name": "Chinese (Simplified)", "nllb": "zho_Hans", "language_code": "zh-CN"},
{"name": "Chinese (Traditional)", "nllb": "zho_Hant", "language_code": "zh-TW"},
{"name": "Standard Malay", "nllb": "zsm_Latn", "language_code": "ms"},
{"name": "Zulu", "nllb": "zul_Latn", "language_code": "zu"},
]
# Lookup tables derived once from LANGUAGES for O(1) conversion between
# display names, NLLB codes, and short language codes.
NAME_TO_NLLB = {entry["name"]: entry["nllb"] for entry in LANGUAGES}
NAME_TO_LANGUAGE_CODE = {entry["name"]: entry["language_code"] for entry in LANGUAGES}
LANGUAGE_CODE_TO_NLLB = {entry["language_code"]: entry["nllb"] for entry in LANGUAGES}
NLLB_TO_LANGUAGE_CODE = {entry["nllb"]: entry["language_code"] for entry in LANGUAGES}
LANGUAGE_CODE_TO_NAME = {entry["language_code"]: entry["name"] for entry in LANGUAGES}
NLLB_TO_NAME = {entry["nllb"]: entry["name"] for entry in LANGUAGES}


def get_nllb_code(language_code_code):
    """Map a short language code to its NLLB code, or None if unknown."""
    return LANGUAGE_CODE_TO_NLLB.get(language_code_code)


def get_language_code_code(nllb_code):
    """Map an NLLB code back to its short language code, or None if unknown."""
    return NLLB_TO_LANGUAGE_CODE.get(nllb_code)


def get_language_name_by_language_code(language_code_code):
    """Map a short language code to the language's display name, or None."""
    return LANGUAGE_CODE_TO_NAME.get(language_code_code)


def get_language_name_by_nllb(nllb_code):
    """Map an NLLB code to the language's display name, or None."""
    return NLLB_TO_NAME.get(nllb_code)
def get_language_info(identifier, identifier_type="auto"):
    """Find a language entry by display name, NLLB code, or language code.

    With identifier_type="auto", all three fields are tried (name match is
    case-insensitive, code matches are exact). With an explicit type, only
    that field is compared. Returns the matching dict or None.
    """
    def _matches(entry):
        if identifier_type in ("auto", "name"):
            if entry["name"].lower() == identifier.lower():
                return True
            if identifier_type == "name":
                return False
        if identifier_type in ("auto", "nllb"):
            if entry["nllb"] == identifier:
                return True
            if identifier_type == "nllb":
                return False
        if identifier_type in ("auto", "language_code"):
            return entry["language_code"] == identifier
        # Unknown identifier_type: nothing matches (original returned None).
        return False

    return next((entry for entry in LANGUAGES if _matches(entry)), None)
def list_all_languages():
    """Return every supported language's display name, in table order."""
    return [record["name"] for record in LANGUAGES]


def list_all_nllb_codes():
    """Return every supported language's NLLB code, in table order."""
    return [record["nllb"] for record in LANGUAGES]


def list_all_language_code_codes():
    """Return every supported language's short code, in table order."""
    return [record["language_code"] for record in LANGUAGES]

View File

@@ -1,171 +0,0 @@
import logging
import time
import ctranslate2
import torch
import transformers
from dataclasses import dataclass, field
import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code
from whisperlivekit.timed_objects import Translation
logger = logging.getLogger(__name__)
# In the diarization case, we may want to translate just one speaker, or at
# least start the sentences there.
# After a silence of this many seconds, the translation buffer is dropped
# even if the previous sentence is not finished.
MIN_SILENCE_DURATION_DEL_BUFFER = 3
@dataclass
class TranslationModel():
    """A loaded NLLB translator plus a per-source-language tokenizer cache.

    Shared across OnlineTranslation instances so the (large) model is
    loaded only once.
    """
    # ctranslate2.Translator, or a transformers Seq2Seq model when
    # backend_type == 'transformers'.
    translator: ctranslate2.Translator
    device: str
    # src_lang -> transformers.AutoTokenizer, filled lazily.
    tokenizer: dict = field(default_factory=dict)
    backend_type: str = 'ctranslate2'
    nllb_size: str = '600M'
    def get_tokenizer(self, input_lang):
        """Return the tokenizer for *input_lang*, creating and caching it on first use."""
        if not self.tokenizer.get(input_lang, False):
            self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained(
                f"facebook/nllb-200-distilled-{self.nllb_size}",
                src_lang=input_lang,
                clean_up_tokenization_spaces=True
            )
        return self.tokenizer[input_lang]
def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
    """Load an NLLB translation model and one tokenizer per source language.

    Args:
        src_langs: NLLB source-language codes; 'auto' entries are skipped
            (their tokenizer is created lazily once the language is detected).
        nllb_backend: 'ctranslate2' (downloads a converted snapshot from the
            Hugging Face hub) or 'transformers' (loads facebook/nllb-200-distilled-*).
        nllb_size: distilled checkpoint size, e.g. '600M'.

    Returns:
        A TranslationModel wrapping the backend translator and its tokenizers.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if nllb_backend=='ctranslate2':
        model = f'nllb-200-distilled-{nllb_size}-ctranslate2'
        # Hub user hosting the ctranslate2 conversions of the NLLB checkpoints.
        MODEL_GUY = 'entai2965'
        huggingface_hub.snapshot_download(MODEL_GUY + '/' + model,local_dir=model)
        translator = ctranslate2.Translator(model,device=device)
    elif nllb_backend=='transformers':
        model = f"facebook/nllb-200-distilled-{nllb_size}"
        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(model)
    tokenizer = dict()
    for src_lang in src_langs:
        if src_lang != 'auto':
            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(model, src_lang=src_lang, clean_up_tokenization_spaces=True)
    translation_model = TranslationModel(
        translator=translator,
        tokenizer=tokenizer,
        backend_type=nllb_backend,
        device = device,
        nllb_size = nllb_size
    )
    # NOTE(review): this second loop is redundant — the tokenizers were just
    # built above and get_tokenizer() finds them already cached.
    for src_lang in src_langs:
        if src_lang != 'auto':
            translation_model.get_tokenizer(src_lang)
    return translation_model
class OnlineTranslation:
    """Incremental translation of a token stream.

    Tokens are pushed with insert_tokens(); process() translates every
    completed sentence (delimited by a punctuation token) exactly once and
    keeps re-translating the unfinished tail on each call.
    """
    def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
        self.input_buffer = []
        self.len_processed_buffer = 0
        self.translation_remaining = Translation()
        self.validated = []
        # Bug fix: compute_common_prefix() appends to this attribute, but it
        # was never initialized, raising AttributeError on first divergence.
        self.commited = []
        self.translation_pending_validation = ''
        self.translation_model = translation_model
        self.input_languages = input_languages
        self.output_languages = output_languages

    def compute_common_prefix(self, results):
        # We don't want to prune the result for the moment.
        if not self.input_buffer:
            self.input_buffer = results
        else:
            for i in range(min(len(self.input_buffer), len(results))):
                if self.input_buffer[i] != results[i]:
                    self.commited.extend(self.input_buffer[:i])
                    self.input_buffer = results[i:]

    def translate(self, input, input_lang, output_lang):
        """Translate *input* text from input_lang to output_lang; '' for empty input."""
        if not input:
            return ""
        nllb_output_lang = get_nllb_code(output_lang)
        tokenizer = self.translation_model.get_tokenizer(input_lang)
        tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
        if self.translation_model.backend_type == 'ctranslate2':
            source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
            results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
            # Drop the forced target-language token before decoding.
            target = results[0].hypotheses[0][1:]
            result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
        else:
            translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
            result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return result

    def translate_tokens(self, tokens):
        """Translate a token list into one Translation covering the tokens' time span.

        Returns None for an empty token list.
        """
        if tokens:
            text = ' '.join([token.text for token in tokens])
            start = tokens[0].start
            end = tokens[-1].end
            if self.input_languages[0] == 'auto':
                # Language auto-detection: trust the first token's detection.
                input_lang = tokens[0].detected_language
            else:
                input_lang = self.input_languages[0]
            translated_text = self.translate(text,
                input_lang,
                self.output_languages[0]
            )
            return Translation(
                text=translated_text,
                start=start,
                end=end,
            )
        return None

    def insert_tokens(self, tokens):
        """Append newly transcribed tokens to the pending buffer."""
        self.input_buffer.extend(tokens)

    def process(self):
        """Translate finished sentences, then the remaining tail.

        Returns:
            (validated, [remaining]) — the list of sentence-level
            translations validated so far, and the current tail translation.
        """
        i = 0
        if len(self.input_buffer) < self.len_processed_buffer + 3:  # nothing new to process
            # Bug fix: this early exit previously returned a single merged
            # list while the main path returns a 2-tuple, so callers
            # unpacking two values broke intermittently.
            return self.validated, [self.translation_remaining]
        while i < len(self.input_buffer):
            if self.input_buffer[i].is_punctuation():
                # Sentence boundary: translate it once and commit it.
                translation_sentence = self.translate_tokens(self.input_buffer[:i+1])
                self.validated.append(translation_sentence)
                self.input_buffer = self.input_buffer[i+1:]
                i = 0
            else:
                i += 1
        self.translation_remaining = self.translate_tokens(self.input_buffer)
        self.len_processed_buffer = len(self.input_buffer)
        return self.validated, [self.translation_remaining]

    def insert_silence(self, silence_duration: float):
        """On a long silence, drop the buffer and freeze the pending tail translation."""
        if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
            self.input_buffer = []
            self.validated += [self.translation_remaining]
if __name__ == '__main__':
    # Smoke test: stream an English paragraph in chunks through the
    # transformers backend and print each chunk's French translation.
    output_lang = 'fr'
    input_lang = "en"
    test_string = """
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
"""
    test = test_string.split(' ')
    step = len(test) // 3
    shared_model = load_model([input_lang], nllb_backend='transformers')
    online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
    beg_inference = time.time()
    # NOTE(review): range(5) with step = len//3 means the last ~2 slices are
    # empty; translate() returns "" for those, so only blank lines print.
    for id in range(5):
        val = test[id*step : (id+1)*step]
        val_str = ' '.join(val)
        result = online_translation.translate(val_str, input_lang = input_lang, output_lang = output_lang)
        print(result)
    print('inference time:', time.time() - beg_inference)