diff --git a/README.md b/README.md index 58276b7..2fbd1da 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,8 @@ Real-time transcription directly to your browser, with a ready-to-use backend+se #### Powered by Leading Research: -- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufalSimul/Streaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408) -- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages. +- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408) +- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simultaneous translation from & to 200 languages. - [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf) - [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization - [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization @@ -68,9 +68,9 @@ Go to `chrome-extension` for instructions. 
| Optional | `pip install` | |-----------|-------------| -| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` | -| **Apple Silicon optimized backend** | `mlx-whisper` | -| **NLLB Translation** | `huggingface_hub` & `transformers` | +| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` | +| **Apple Silicon optimizations** | `mlx-whisper` | +| **Translation** | `nllw` | | *[Not recommanded]* Speaker diarization with Diart | `diart` | | *[Not recommanded]* Original Whisper backend | `whisper` | | *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` | diff --git a/docs/supported_languages.md b/docs/supported_languages.md index a04443e..e6a26f9 100644 --- a/docs/supported_languages.md +++ b/docs/supported_languages.md @@ -26,7 +26,7 @@ whisperlivekit-server --target-language fra_Latn ### Python API ```python -from whisperlivekit.translation import get_language_info +from nllw.translation import get_language_info # Get language information by name lang_info = get_language_info("French") diff --git a/pyproject.toml b/pyproject.toml index 5666f88..1da2451 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ ] [project.optional-dependencies] -translation = ["transformers", "huggingface_hub"] +translation = ["nllw"] sentence_tokenizer = ["mosestokenizer", "wtpsplit"] [project.urls] @@ -60,7 +60,6 @@ packages = [ "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom", - "whisperlivekit.translation", "whisperlivekit.vad_models" ] diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 955a3aa..f5a4173 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -141,9 +141,9 @@ class TranscriptionEngine: if self.args.lan == 'auto' and self.args.backend != "simulstreaming": raise Exception('Translation cannot be set with language auto when 
transcription backend is not simulstreaming') else: - from whisperlivekit.translation.translation import load_model + from nllw import load_model translation_params = { - "nllb_backend": "ctranslate2", + "nllb_backend": "transformers", "nllb_size": "600M" } translation_params = update_with_kwargs(translation_params, kwargs) @@ -175,5 +175,5 @@ def online_translation_factory(args, translation_model): #should be at speaker level in the future: #one shared nllb model for all speaker #one tokenizer per speaker/language - from whisperlivekit.translation.translation import OnlineTranslation + from nllw import OnlineTranslation return OnlineTranslation(translation_model, [args.lan], [args.target_language]) diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index fe50023..e7db271 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -300,7 +300,7 @@ def parse_args(): simulstreaming_group.add_argument( "--nllb-backend", type=str, - default="ctranslate2", + default="transformers", help="transformers or ctranslate2", ) diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index ee248ef..0a8ef11 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -222,7 +222,7 @@ class SimulStreamingASR(): self.mlx_encoder, self.fw_encoder = None, None if not self.disable_fast_encoder: if HAS_MLX_WHISPER: - print('Simulstreaming will use MLX whisper for a faster encoder.') + print('Simulstreaming will use MLX whisper to increase encoding speed.') if self.model_path and compatible_whisper_mlx: mlx_model = self.model_path else: diff --git a/whisperlivekit/translation/__init__.py b/whisperlivekit/translation/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/whisperlivekit/translation/fast_translation.py b/whisperlivekit/translation/fast_translation.py deleted file mode 100644 index 1226e46..0000000 --- 
a/whisperlivekit/translation/fast_translation.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM -from typing import Tuple, Optional - -model_name = "facebook/nllb-200-distilled-600M" -tokenizer = AutoTokenizer.from_pretrained(model_name) -model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu") - -device = model.device -bos_token_id = tokenizer.convert_tokens_to_ids("fra_Latn") - - -def compute_common_prefix_tokens( - prev_tokens: torch.Tensor, - new_tokens: torch.Tensor, - tokenizer: AutoTokenizer, - sep: str = " " -) -> torch.Tensor: - if prev_tokens is None or len(prev_tokens) == 0: - return new_tokens - - prev_text = tokenizer.decode(prev_tokens, skip_special_tokens=True) - new_text = tokenizer.decode(new_tokens, skip_special_tokens=True) - - if not prev_text or not new_text: - return new_tokens - - prev_words = prev_text.split(sep) - new_words = new_text.split(sep) - - common_word_count = 0 - for i in range(min(len(prev_words), len(new_words))): - if prev_words[i] == new_words[i]: - common_word_count += 1 - else: - break - - if common_word_count == 0: - if len(new_tokens) > 0: - return new_tokens[:1] - return new_tokens - - if common_word_count == len(prev_words) and len(prev_words) == len(new_words): - return new_tokens - - common_prefix_text = sep.join(new_words[:common_word_count]) - - for token_idx in range(1, len(new_tokens) + 1): - decoded = tokenizer.decode(new_tokens[:token_idx], skip_special_tokens=True) - if decoded == common_prefix_text: - return new_tokens[:token_idx] - if len(decoded) > len(common_prefix_text): - return new_tokens[:max(1, token_idx - 1)] - - return new_tokens - - -def manual_generate( - encoder_outputs: torch.Tensor, - attention_mask: torch.Tensor, - forced_bos_token_id: int, - max_length: int = 50, - eos_token_id: Optional[int] = None -) -> Tuple[torch.Tensor, Optional[Tuple]]: - if eos_token_id is None: - 
eos_token_id = tokenizer.eos_token_id - - with torch.no_grad(): - generated_tokens = model.generate( - encoder_outputs=encoder_outputs, - attention_mask=attention_mask, - forced_bos_token_id=forced_bos_token_id, - ) - - return generated_tokens, None - - -def continue_generation_with_cache( - encoder_hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - prefix_tokens: torch.Tensor, - max_new_tokens: int = 50, - eos_token_id: Optional[int] = None -) -> torch.Tensor: - if eos_token_id is None: - eos_token_id = tokenizer.eos_token_id - - prefix_tokens = prefix_tokens.to(device) - - if prefix_tokens.dim() == 1: - prefix_tokens = prefix_tokens.unsqueeze(0) - - with torch.no_grad(): - decoder_out = model.model.decoder( - input_ids=prefix_tokens, - encoder_hidden_states=encoder_hidden_states, - use_cache=True, - return_dict=True, - ) - prefix_logits = model.lm_head(decoder_out.last_hidden_state) - past_key_values = decoder_out.past_key_values - - next_token_id = torch.argmax(prefix_logits[:, -1, :], dim=-1).unsqueeze(-1) - - if next_token_id.item() == eos_token_id: - return prefix_tokens - - generated_tokens = torch.cat([prefix_tokens, next_token_id], dim=-1) - - tokens_to_generate = max_new_tokens - generated_tokens.shape[1] - - for _ in range(tokens_to_generate): - decoder_out = model.model.decoder( - input_ids=next_token_id, - encoder_hidden_states=encoder_hidden_states, - past_key_values=past_key_values, - use_cache=True, - return_dict=True, - ) - logits = model.lm_head(decoder_out.last_hidden_state) - - next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1) - past_key_values = decoder_out.past_key_values - - if next_token_id.item() == eos_token_id: - break - - generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1) - - return generated_tokens - - -if __name__ == "__main__": - src_text_0 = "Have you noticed how accurate" - src_text_1 = " LLM are now that GPU have became more powerful" - - inputs_0 = tokenizer(src_text_0, 
return_tensors="pt").to(device) - encoder_outputs_0 = model.get_encoder()(**inputs_0) - - translation_tokens_0, _ = manual_generate( - encoder_outputs=encoder_outputs_0, - attention_mask=inputs_0['attention_mask'], - forced_bos_token_id=bos_token_id, - ) - - print(f"translation 0: {tokenizer.decode(translation_tokens_0[0], skip_special_tokens=True)}") - - src_text_full = src_text_0 + src_text_1 - inputs_1 = tokenizer(src_text_full, return_tensors="pt").to(device) - encoder_outputs_1 = model.get_encoder()(**inputs_1) - - translation_tokens_1 = continue_generation_with_cache( - encoder_hidden_states=encoder_outputs_1.last_hidden_state, - attention_mask=inputs_1['attention_mask'], - prefix_tokens=translation_tokens_0[0].clone(), - max_new_tokens=200 - ) - - print(f"1: {tokenizer.decode(translation_tokens_1[0], skip_special_tokens=True)}") diff --git a/whisperlivekit/translation/mapping_languages.py b/whisperlivekit/translation/mapping_languages.py deleted file mode 100644 index c7bff7e..0000000 --- a/whisperlivekit/translation/mapping_languages.py +++ /dev/null @@ -1,264 +0,0 @@ -LANGUAGES = [ - {"name": "Acehnese (Arabic script)", "nllb": "ace_Arab", "language_code": "ace_Arab"}, - {"name": "Acehnese (Latin script)", "nllb": "ace_Latn", "language_code": "ace_Latn"}, - {"name": "Mesopotamian Arabic", "nllb": "acm_Arab", "language_code": "acm_Arab"}, - {"name": "Ta'izzi-Adeni Arabic", "nllb": "acq_Arab", "language_code": "acq_Arab"}, - {"name": "Tunisian Arabic", "nllb": "aeb_Arab", "language_code": "aeb_Arab"}, - {"name": "Afrikaans", "nllb": "afr_Latn", "language_code": "af"}, - {"name": "South Levantine Arabic", "nllb": "ajp_Arab", "language_code": "ajp_Arab"}, - {"name": "Akan", "nllb": "aka_Latn", "language_code": "ak"}, - {"name": "Tosk Albanian", "nllb": "als_Latn", "language_code": "als"}, - {"name": "Amharic", "nllb": "amh_Ethi", "language_code": "am"}, - {"name": "North Levantine Arabic", "nllb": "apc_Arab", "language_code": "apc_Arab"}, - {"name": "Modern 
Standard Arabic", "nllb": "arb_Arab", "language_code": "ar"}, - {"name": "Modern Standard Arabic (Romanized)", "nllb": "arb_Latn", "language_code": "arb_Latn"}, - {"name": "Najdi Arabic", "nllb": "ars_Arab", "language_code": "ars_Arab"}, - {"name": "Moroccan Arabic", "nllb": "ary_Arab", "language_code": "ary_Arab"}, - {"name": "Egyptian Arabic", "nllb": "arz_Arab", "language_code": "arz_Arab"}, - {"name": "Assamese", "nllb": "asm_Beng", "language_code": "as"}, - {"name": "Asturian", "nllb": "ast_Latn", "language_code": "ast"}, - {"name": "Awadhi", "nllb": "awa_Deva", "language_code": "awa"}, - {"name": "Central Aymara", "nllb": "ayr_Latn", "language_code": "ay"}, - {"name": "South Azerbaijani", "nllb": "azb_Arab", "language_code": "azb"}, - {"name": "North Azerbaijani", "nllb": "azj_Latn", "language_code": "az"}, - {"name": "Bashkir", "nllb": "bak_Cyrl", "language_code": "ba"}, - {"name": "Bambara", "nllb": "bam_Latn", "language_code": "bm"}, - {"name": "Balinese", "nllb": "ban_Latn", "language_code": "ban"}, - {"name": "Belarusian", "nllb": "bel_Cyrl", "language_code": "be"}, - {"name": "Bemba", "nllb": "bem_Latn", "language_code": "bem"}, - {"name": "Bengali", "nllb": "ben_Beng", "language_code": "bn"}, - {"name": "Bhojpuri", "nllb": "bho_Deva", "language_code": "bho"}, - {"name": "Banjar (Arabic script)", "nllb": "bjn_Arab", "language_code": "bjn_Arab"}, - {"name": "Banjar (Latin script)", "nllb": "bjn_Latn", "language_code": "bjn_Latn"}, - {"name": "Standard Tibetan", "nllb": "bod_Tibt", "language_code": "bo"}, - {"name": "Bosnian", "nllb": "bos_Latn", "language_code": "bs"}, - {"name": "Buginese", "nllb": "bug_Latn", "language_code": "bug"}, - {"name": "Bulgarian", "nllb": "bul_Cyrl", "language_code": "bg"}, - {"name": "Catalan", "nllb": "cat_Latn", "language_code": "ca"}, - {"name": "Cebuano", "nllb": "ceb_Latn", "language_code": "ceb"}, - {"name": "Czech", "nllb": "ces_Latn", "language_code": "cs"}, - {"name": "Chokwe", "nllb": "cjk_Latn", "language_code": 
"cjk"}, - {"name": "Central Kurdish", "nllb": "ckb_Arab", "language_code": "ckb"}, - {"name": "Crimean Tatar", "nllb": "crh_Latn", "language_code": "crh"}, - {"name": "Welsh", "nllb": "cym_Latn", "language_code": "cy"}, - {"name": "Danish", "nllb": "dan_Latn", "language_code": "da"}, - {"name": "German", "nllb": "deu_Latn", "language_code": "de"}, - {"name": "Southwestern Dinka", "nllb": "dik_Latn", "language_code": "dik"}, - {"name": "Dyula", "nllb": "dyu_Latn", "language_code": "dyu"}, - {"name": "Dzongkha", "nllb": "dzo_Tibt", "language_code": "dz"}, - {"name": "Greek", "nllb": "ell_Grek", "language_code": "el"}, - {"name": "English", "nllb": "eng_Latn", "language_code": "en"}, - {"name": "Esperanto", "nllb": "epo_Latn", "language_code": "eo"}, - {"name": "Estonian", "nllb": "est_Latn", "language_code": "et"}, - {"name": "Basque", "nllb": "eus_Latn", "language_code": "eu"}, - {"name": "Ewe", "nllb": "ewe_Latn", "language_code": "ee"}, - {"name": "Faroese", "nllb": "fao_Latn", "language_code": "fo"}, - {"name": "Fijian", "nllb": "fij_Latn", "language_code": "fj"}, - {"name": "Finnish", "nllb": "fin_Latn", "language_code": "fi"}, - {"name": "Fon", "nllb": "fon_Latn", "language_code": "fon"}, - {"name": "French", "nllb": "fra_Latn", "language_code": "fr"}, - {"name": "Friulian", "nllb": "fur_Latn", "language_code": "fur-IT"}, - {"name": "Nigerian Fulfulde", "nllb": "fuv_Latn", "language_code": "fuv"}, - {"name": "West Central Oromo", "nllb": "gaz_Latn", "language_code": "om"}, - {"name": "Scottish Gaelic", "nllb": "gla_Latn", "language_code": "gd"}, - {"name": "Irish", "nllb": "gle_Latn", "language_code": "ga-IE"}, - {"name": "Galician", "nllb": "glg_Latn", "language_code": "gl"}, - {"name": "Guarani", "nllb": "grn_Latn", "language_code": "gn"}, - {"name": "Gujarati", "nllb": "guj_Gujr", "language_code": "gu-IN"}, - {"name": "Haitian Creole", "nllb": "hat_Latn", "language_code": "ht"}, - {"name": "Hausa", "nllb": "hau_Latn", "language_code": "ha"}, - {"name": 
"Hebrew", "nllb": "heb_Hebr", "language_code": "he"}, - {"name": "Hindi", "nllb": "hin_Deva", "language_code": "hi"}, - {"name": "Chhattisgarhi", "nllb": "hne_Deva", "language_code": "hne"}, - {"name": "Croatian", "nllb": "hrv_Latn", "language_code": "hr"}, - {"name": "Hungarian", "nllb": "hun_Latn", "language_code": "hu"}, - {"name": "Armenian", "nllb": "hye_Armn", "language_code": "hy-AM"}, - {"name": "Igbo", "nllb": "ibo_Latn", "language_code": "ig"}, - {"name": "Ilocano", "nllb": "ilo_Latn", "language_code": "ilo"}, - {"name": "Indonesian", "nllb": "ind_Latn", "language_code": "id"}, - {"name": "Icelandic", "nllb": "isl_Latn", "language_code": "is"}, - {"name": "Italian", "nllb": "ita_Latn", "language_code": "it"}, - {"name": "Javanese", "nllb": "jav_Latn", "language_code": "jv"}, - {"name": "Japanese", "nllb": "jpn_Jpan", "language_code": "ja"}, - {"name": "Kabyle", "nllb": "kab_Latn", "language_code": "kab"}, - {"name": "Jingpho", "nllb": "kac_Latn", "language_code": "kac"}, - {"name": "Kamba", "nllb": "kam_Latn", "language_code": "kam"}, - {"name": "Kannada", "nllb": "kan_Knda", "language_code": "kn"}, - {"name": "Kashmiri (Arabic script)", "nllb": "kas_Arab", "language_code": "kas_Arab"}, - {"name": "Kashmiri (Devanagari script)", "nllb": "kas_Deva", "language_code": "kas_Deva"}, - {"name": "Georgian", "nllb": "kat_Geor", "language_code": "ka"}, - {"name": "Kazakh", "nllb": "kaz_Cyrl", "language_code": "kk"}, - {"name": "Kabiyè", "nllb": "kbp_Latn", "language_code": "kbp"}, - {"name": "Kabuverdianu", "nllb": "kea_Latn", "language_code": "kea"}, - {"name": "Halh Mongolian", "nllb": "khk_Cyrl", "language_code": "mn"}, - {"name": "Khmer", "nllb": "khm_Khmr", "language_code": "km"}, - {"name": "Kikuyu", "nllb": "kik_Latn", "language_code": "ki"}, - {"name": "Kinyarwanda", "nllb": "kin_Latn", "language_code": "rw"}, - {"name": "Kyrgyz", "nllb": "kir_Cyrl", "language_code": "ky"}, - {"name": "Kimbundu", "nllb": "kmb_Latn", "language_code": "kmb"}, - {"name": 
"Northern Kurdish", "nllb": "kmr_Latn", "language_code": "kmr"}, - {"name": "Central Kanuri (Arabic script)", "nllb": "knc_Arab", "language_code": "knc_Arab"}, - {"name": "Central Kanuri (Latin script)", "nllb": "knc_Latn", "language_code": "knc_Latn"}, - {"name": "Kikongo", "nllb": "kon_Latn", "language_code": "kg"}, - {"name": "Korean", "nllb": "kor_Hang", "language_code": "ko"}, - {"name": "Lao", "nllb": "lao_Laoo", "language_code": "lo"}, - {"name": "Ligurian", "nllb": "lij_Latn", "language_code": "lij"}, - {"name": "Limburgish", "nllb": "lim_Latn", "language_code": "li"}, - {"name": "Lingala", "nllb": "lin_Latn", "language_code": "ln"}, - {"name": "Lithuanian", "nllb": "lit_Latn", "language_code": "lt"}, - {"name": "Lombard", "nllb": "lmo_Latn", "language_code": "lmo"}, - {"name": "Latgalian", "nllb": "ltg_Latn", "language_code": "ltg"}, - {"name": "Luxembourgish", "nllb": "ltz_Latn", "language_code": "lb"}, - {"name": "Luba-Kasai", "nllb": "lua_Latn", "language_code": "lua"}, - {"name": "Ganda", "nllb": "lug_Latn", "language_code": "lg"}, - {"name": "Luo", "nllb": "luo_Latn", "language_code": "luo"}, - {"name": "Mizo", "nllb": "lus_Latn", "language_code": "lus"}, - {"name": "Standard Latvian", "nllb": "lvs_Latn", "language_code": "lv"}, - {"name": "Magahi", "nllb": "mag_Deva", "language_code": "mag"}, - {"name": "Maithili", "nllb": "mai_Deva", "language_code": "mai"}, - {"name": "Malayalam", "nllb": "mal_Mlym", "language_code": "ml-IN"}, - {"name": "Marathi", "nllb": "mar_Deva", "language_code": "mr"}, - {"name": "Minangkabau (Arabic script)", "nllb": "min_Arab", "language_code": "min_Arab"}, - {"name": "Minangkabau (Latin script)", "nllb": "min_Latn", "language_code": "min_Latn"}, - {"name": "Macedonian", "nllb": "mkd_Cyrl", "language_code": "mk"}, - {"name": "Maltese", "nllb": "mlt_Latn", "language_code": "mt"}, - {"name": "Meitei (Bengali script)", "nllb": "mni_Beng", "language_code": "mni"}, - {"name": "Mossi", "nllb": "mos_Latn", "language_code": "mos"}, 
- {"name": "Maori", "nllb": "mri_Latn", "language_code": "mi"}, - {"name": "Burmese", "nllb": "mya_Mymr", "language_code": "my"}, - {"name": "Dutch", "nllb": "nld_Latn", "language_code": "nl"}, - {"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "language_code": "nn-NO"}, - {"name": "Norwegian Bokmål", "nllb": "nob_Latn", "language_code": "nb"}, - {"name": "Nepali", "nllb": "npi_Deva", "language_code": "ne-NP"}, - {"name": "Northern Sotho", "nllb": "nso_Latn", "language_code": "nso"}, - {"name": "Nuer", "nllb": "nus_Latn", "language_code": "nus"}, - {"name": "Nyanja", "nllb": "nya_Latn", "language_code": "ny"}, - {"name": "Occitan", "nllb": "oci_Latn", "language_code": "oc"}, - {"name": "Odia", "nllb": "ory_Orya", "language_code": "or"}, - {"name": "Pangasinan", "nllb": "pag_Latn", "language_code": "pag"}, - {"name": "Eastern Panjabi", "nllb": "pan_Guru", "language_code": "pa"}, - {"name": "Papiamento", "nllb": "pap_Latn", "language_code": "pap"}, - {"name": "Southern Pashto", "nllb": "pbt_Arab", "language_code": "pbt"}, - {"name": "Western Persian", "nllb": "pes_Arab", "language_code": "fa"}, - {"name": "Plateau Malagasy", "nllb": "plt_Latn", "language_code": "mg"}, - {"name": "Polish", "nllb": "pol_Latn", "language_code": "pl"}, - {"name": "Portuguese", "nllb": "por_Latn", "language_code": "pt-PT"}, - {"name": "Dari", "nllb": "prs_Arab", "language_code": "fa-AF"}, - {"name": "Ayacucho Quechua", "nllb": "quy_Latn", "language_code": "qu"}, - {"name": "Romanian", "nllb": "ron_Latn", "language_code": "ro"}, - {"name": "Rundi", "nllb": "run_Latn", "language_code": "rn"}, - {"name": "Russian", "nllb": "rus_Cyrl", "language_code": "ru"}, - {"name": "Sango", "nllb": "sag_Latn", "language_code": "sg"}, - {"name": "Sanskrit", "nllb": "san_Deva", "language_code": "sa"}, - {"name": "Santali", "nllb": "sat_Olck", "language_code": "sat"}, - {"name": "Sicilian", "nllb": "scn_Latn", "language_code": "scn"}, - {"name": "Shan", "nllb": "shn_Mymr", "language_code": "shn"}, - 
{"name": "Sinhala", "nllb": "sin_Sinh", "language_code": "si-LK"}, - {"name": "Slovak", "nllb": "slk_Latn", "language_code": "sk"}, - {"name": "Slovenian", "nllb": "slv_Latn", "language_code": "sl"}, - {"name": "Samoan", "nllb": "smo_Latn", "language_code": "sm"}, - {"name": "Shona", "nllb": "sna_Latn", "language_code": "sn"}, - {"name": "Sindhi", "nllb": "snd_Arab", "language_code": "sd"}, - {"name": "Somali", "nllb": "som_Latn", "language_code": "so"}, - {"name": "Southern Sotho", "nllb": "sot_Latn", "language_code": "st"}, - {"name": "Spanish", "nllb": "spa_Latn", "language_code": "es-ES"}, - {"name": "Sardinian", "nllb": "srd_Latn", "language_code": "sc"}, - {"name": "Serbian", "nllb": "srp_Cyrl", "language_code": "sr"}, - {"name": "Swati", "nllb": "ssw_Latn", "language_code": "ss"}, - {"name": "Sundanese", "nllb": "sun_Latn", "language_code": "su"}, - {"name": "Swedish", "nllb": "swe_Latn", "language_code": "sv-SE"}, - {"name": "Swahili", "nllb": "swh_Latn", "language_code": "sw"}, - {"name": "Silesian", "nllb": "szl_Latn", "language_code": "szl"}, - {"name": "Tamil", "nllb": "tam_Taml", "language_code": "ta"}, - {"name": "Tamasheq (Latin script)", "nllb": "taq_Latn", "language_code": "taq_Latn"}, - {"name": "Tamasheq (Tifinagh script)", "nllb": "taq_Tfng", "language_code": "taq_Tfng"}, - {"name": "Tatar", "nllb": "tat_Cyrl", "language_code": "tt-RU"}, - {"name": "Telugu", "nllb": "tel_Telu", "language_code": "te"}, - {"name": "Tajik", "nllb": "tgk_Cyrl", "language_code": "tg"}, - {"name": "Tagalog", "nllb": "tgl_Latn", "language_code": "tl"}, - {"name": "Thai", "nllb": "tha_Thai", "language_code": "th"}, - {"name": "Tigrinya", "nllb": "tir_Ethi", "language_code": "ti"}, - {"name": "Tok Pisin", "nllb": "tpi_Latn", "language_code": "tpi"}, - {"name": "Tswana", "nllb": "tsn_Latn", "language_code": "tn"}, - {"name": "Tsonga", "nllb": "tso_Latn", "language_code": "ts"}, - {"name": "Turkmen", "nllb": "tuk_Latn", "language_code": "tk"}, - {"name": "Tumbuka", "nllb": 
"tum_Latn", "language_code": "tum"}, - {"name": "Turkish", "nllb": "tur_Latn", "language_code": "tr"}, - {"name": "Twi", "nllb": "twi_Latn", "language_code": "tw"}, - {"name": "Central Atlas Tamazight", "nllb": "tzm_Tfng", "language_code": "tzm"}, - {"name": "Uyghur", "nllb": "uig_Arab", "language_code": "ug"}, - {"name": "Ukrainian", "nllb": "ukr_Cyrl", "language_code": "uk"}, - {"name": "Umbundu", "nllb": "umb_Latn", "language_code": "umb"}, - {"name": "Urdu", "nllb": "urd_Arab", "language_code": "ur"}, - {"name": "Northern Uzbek", "nllb": "uzn_Latn", "language_code": "uz"}, - {"name": "Venetian", "nllb": "vec_Latn", "language_code": "vec"}, - {"name": "Vietnamese", "nllb": "vie_Latn", "language_code": "vi"}, - {"name": "Waray", "nllb": "war_Latn", "language_code": "war"}, - {"name": "Wolof", "nllb": "wol_Latn", "language_code": "wo"}, - {"name": "Xhosa", "nllb": "xho_Latn", "language_code": "xh"}, - {"name": "Eastern Yiddish", "nllb": "ydd_Hebr", "language_code": "yi"}, - {"name": "Yoruba", "nllb": "yor_Latn", "language_code": "yo"}, - {"name": "Yue Chinese", "nllb": "yue_Hant", "language_code": "yue"}, - {"name": "Chinese (Simplified)", "nllb": "zho_Hans", "language_code": "zh-CN"}, - {"name": "Chinese (Traditional)", "nllb": "zho_Hant", "language_code": "zh-TW"}, - {"name": "Standard Malay", "nllb": "zsm_Latn", "language_code": "ms"}, - {"name": "Zulu", "nllb": "zul_Latn", "language_code": "zu"}, -] - -NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES} -NAME_TO_LANGUAGE_CODE = {lang["name"]: lang["language_code"] for lang in LANGUAGES} -LANGUAGE_CODE_TO_NLLB = {lang["language_code"]: lang["nllb"] for lang in LANGUAGES} -NLLB_TO_LANGUAGE_CODE = {lang["nllb"]: lang["language_code"] for lang in LANGUAGES} -LANGUAGE_CODE_TO_NAME = {lang["language_code"]: lang["name"] for lang in LANGUAGES} -NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES} - - -def get_nllb_code(language_code_code): - return 
LANGUAGE_CODE_TO_NLLB.get(language_code_code, None) - - -def get_language_code_code(nllb_code): - return NLLB_TO_LANGUAGE_CODE.get(nllb_code) - - -def get_language_name_by_language_code(language_code_code): - return LANGUAGE_CODE_TO_NAME.get(language_code_code) - - -def get_language_name_by_nllb(nllb_code): - return NLLB_TO_NAME.get(nllb_code) - - -def get_language_info(identifier, identifier_type="auto"): - if identifier_type == "auto": - for lang in LANGUAGES: - if (lang["name"].lower() == identifier.lower() or - lang["nllb"] == identifier or - lang["language_code"] == identifier): - return lang - elif identifier_type == "name": - for lang in LANGUAGES: - if lang["name"].lower() == identifier.lower(): - return lang - elif identifier_type == "nllb": - for lang in LANGUAGES: - if lang["nllb"] == identifier: - return lang - elif identifier_type == "language_code": - for lang in LANGUAGES: - if lang["language_code"] == identifier: - return lang - - return None - - -def list_all_languages(): - return [lang["name"] for lang in LANGUAGES] - - -def list_all_nllb_codes(): - return [lang["nllb"] for lang in LANGUAGES] - - -def list_all_language_code_codes(): - return [lang["language_code"] for lang in LANGUAGES] diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py deleted file mode 100644 index f1d5c14..0000000 --- a/whisperlivekit/translation/translation.py +++ /dev/null @@ -1,171 +0,0 @@ -import logging -import time -import ctranslate2 -import torch -import transformers -from dataclasses import dataclass, field -import huggingface_hub -from whisperlivekit.translation.mapping_languages import get_nllb_code -from whisperlivekit.timed_objects import Translation - -logger = logging.getLogger(__name__) - -#In diarization case, we may want to translate just one speaker, or at least start the sentences there - -MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider the model should not use the buffer, even if 
the previous -# sentence is not finished. - -@dataclass -class TranslationModel(): - translator: ctranslate2.Translator - device: str - tokenizer: dict = field(default_factory=dict) - backend_type: str = 'ctranslate2' - nllb_size: str = '600M' - - def get_tokenizer(self, input_lang): - if not self.tokenizer.get(input_lang, False): - self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained( - f"facebook/nllb-200-distilled-{self.nllb_size}", - src_lang=input_lang, - clean_up_tokenization_spaces=True - ) - return self.tokenizer[input_lang] - - -def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'): - device = "cuda" if torch.cuda.is_available() else "cpu" - - if nllb_backend=='ctranslate2': - model = f'nllb-200-distilled-{nllb_size}-ctranslate2' - MODEL_GUY = 'entai2965' - huggingface_hub.snapshot_download(MODEL_GUY + '/' + model,local_dir=model) - translator = ctranslate2.Translator(model,device=device) - elif nllb_backend=='transformers': - model = f"facebook/nllb-200-distilled-{nllb_size}" - translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(model) - tokenizer = dict() - for src_lang in src_langs: - if src_lang != 'auto': - tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(model, src_lang=src_lang, clean_up_tokenization_spaces=True) - - translation_model = TranslationModel( - translator=translator, - tokenizer=tokenizer, - backend_type=nllb_backend, - device = device, - nllb_size = nllb_size - ) - for src_lang in src_langs: - if src_lang != 'auto': - translation_model.get_tokenizer(src_lang) - return translation_model - -class OnlineTranslation: - def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): - self.input_buffer = [] - self.len_processed_buffer = 0 - self.translation_remaining = Translation() - self.validated = [] - self.translation_pending_validation = '' - self.translation_model = translation_model - self.input_languages = input_languages - 
self.output_languages = output_languages - - def compute_common_prefix(self, results): - #we dont want want to prune the result for the moment. - if not self.input_buffer: - self.input_buffer = results - else: - for i in range(min(len(self.input_buffer), len(results))): - if self.input_buffer[i] != results[i]: - self.commited.extend(self.input_buffer[:i]) - self.input_buffer = results[i:] - - def translate(self, input, input_lang, output_lang): - if not input: - return "" - nllb_output_lang = get_nllb_code(output_lang) - - tokenizer = self.translation_model.get_tokenizer(input_lang) - tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device) - - if self.translation_model.backend_type == 'ctranslate2': - source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0]) - results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) - target = results[0].hypotheses[0][1:] - result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target)) - else: - translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang)) - result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] - return result - - def translate_tokens(self, tokens): - if tokens: - text = ' '.join([token.text for token in tokens]) - start = tokens[0].start - end = tokens[-1].end - if self.input_languages[0] == 'auto': - input_lang = tokens[0].detected_language - else: - input_lang = self.input_languages[0] - - translated_text = self.translate(text, - input_lang, - self.output_languages[0] - ) - translation = Translation( - text=translated_text, - start=start, - end=end, - ) - return translation - return None - - - def insert_tokens(self, tokens): - self.input_buffer.extend(tokens) - pass - - def process(self): - i = 0 - if len(self.input_buffer) < self.len_processed_buffer + 3: #nothing new to process - return 
self.validated + [self.translation_remaining] - while i < len(self.input_buffer): - if self.input_buffer[i].is_punctuation(): - translation_sentence = self.translate_tokens(self.input_buffer[:i+1]) - self.validated.append(translation_sentence) - self.input_buffer = self.input_buffer[i+1:] - i = 0 - else: - i+=1 - self.translation_remaining = self.translate_tokens(self.input_buffer) - self.len_processed_buffer = len(self.input_buffer) - return self.validated, [self.translation_remaining] - - def insert_silence(self, silence_duration: float): - if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER: - self.input_buffer = [] - self.validated += [self.translation_remaining] - -if __name__ == '__main__': - output_lang = 'fr' - input_lang = "en" - - - test_string = """ - Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now? - """ - test = test_string.split(' ') - step = len(test) // 3 - - shared_model = load_model([input_lang], nllb_backend='transformers') - online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang]) - - beg_inference = time.time() - for id in range(5): - val = test[id*step : (id+1)*step] - val_str = ' '.join(val) - result = online_translation.translate(val_str, input_lang = input_lang, output_lang = output_lang) - print(result) - print('inference time:', time.time() - beg_inference) \ No newline at end of file