mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Use optional new separate NLLW package for translation
This commit is contained in:
10
README.md
10
README.md
@@ -18,8 +18,8 @@ Real-time transcription directly to your browser, with a ready-to-use backend+se
|
||||
|
||||
#### Powered by Leading Research:
|
||||
|
||||
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||
- [NLLB](https://arxiv.org/abs/2207.04672), ([distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2)) (2024) - Translation to more than 100 languages.
|
||||
- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
|
||||
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simultaneous translation from & to 200 languages.
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
|
||||
@@ -68,9 +68,9 @@ Go to `chrome-extension` for instructions.
|
||||
|
||||
| Optional | `pip install` |
|
||||
|-----------|-------------|
|
||||
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| **Apple Silicon optimized backend** | `mlx-whisper` |
|
||||
| **NLLB Translation** | `huggingface_hub` & `transformers` |
|
||||
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| **Apple Silicon optimizations** | `mlx-whisper` |
|
||||
| **Translation** | `nllw` |
|
||||
| *[Not recommended]* Speaker diarization with Diart | `diart` |
|
||||
| *[Not recommended]* Original Whisper backend | `whisper` |
|
||||
| *[Not recommended]* Improved timestamps backend | `whisper-timestamped` |
|
||||
|
||||
@@ -26,7 +26,7 @@ whisperlivekit-server --target-language fra_Latn
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from whisperlivekit.translation import get_language_info
|
||||
from nllw.translation import get_language_info
|
||||
|
||||
# Get language information by name
|
||||
lang_info = get_language_info("French")
|
||||
|
||||
@@ -41,7 +41,7 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
translation = ["transformers", "huggingface_hub"]
|
||||
translation = ["nllw"]
|
||||
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
|
||||
|
||||
[project.urls]
|
||||
@@ -60,7 +60,6 @@ packages = [
|
||||
"whisperlivekit.simul_whisper.whisper.normalizers",
|
||||
"whisperlivekit.web",
|
||||
"whisperlivekit.whisper_streaming_custom",
|
||||
"whisperlivekit.translation",
|
||||
"whisperlivekit.vad_models"
|
||||
]
|
||||
|
||||
|
||||
@@ -141,9 +141,9 @@ class TranscriptionEngine:
|
||||
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
from nllw import load_model
|
||||
translation_params = {
|
||||
"nllb_backend": "ctranslate2",
|
||||
"nllb_backend": "transformers",
|
||||
"nllb_size": "600M"
|
||||
}
|
||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||
@@ -175,5 +175,5 @@ def online_translation_factory(args, translation_model):
|
||||
#should be at speaker level in the future:
|
||||
#one shared nllb model for all speaker
|
||||
#one tokenizer per speaker/language
|
||||
from whisperlivekit.translation.translation import OnlineTranslation
|
||||
from nllw import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
|
||||
@@ -300,7 +300,7 @@ def parse_args():
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="ctranslate2",
|
||||
default="transformers",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class SimulStreamingASR():
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if not self.disable_fast_encoder:
|
||||
if HAS_MLX_WHISPER:
|
||||
print('Simulstreaming will use MLX whisper for a faster encoder.')
|
||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||
if self.model_path and compatible_whisper_mlx:
|
||||
mlx_model = self.model_path
|
||||
else:
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
from typing import Tuple, Optional
|
||||
|
||||
# Demo/reference script: incremental NLLB translation with manual greedy decoding.
# The distilled 600M NLLB checkpoint is loaded once at import time.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prefer GPU when available; all tensors below must live on the same device as the model.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")

device = model.device
# Target-language token id, forced as the first decoder token (French here).
bos_token_id = tokenizer.convert_tokens_to_ids("fra_Latn")
|
||||
|
||||
|
||||
def compute_common_prefix_tokens(
    prev_tokens: torch.Tensor,
    new_tokens: torch.Tensor,
    tokenizer: AutoTokenizer,
    sep: str = " "
) -> torch.Tensor:
    """Return the leading slice of ``new_tokens`` whose decoded text agrees,
    word for word, with the decoding of ``prev_tokens``.

    Used to keep only the stable (already-agreed) part of an incremental
    translation. Falls back to returning ``new_tokens`` unchanged when no
    comparison is possible, and to a single token when nothing agrees.
    """
    if prev_tokens is None or len(prev_tokens) == 0:
        return new_tokens

    old_text = tokenizer.decode(prev_tokens, skip_special_tokens=True)
    fresh_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    if not old_text or not fresh_text:
        return new_tokens

    old_words = old_text.split(sep)
    fresh_words = fresh_text.split(sep)

    # Count leading words that match between the two decodings.
    matched = 0
    for old_word, fresh_word in zip(old_words, fresh_words):
        if old_word != fresh_word:
            break
        matched += 1

    if matched == 0:
        # Nothing agrees: keep at least one token so the caller never loses everything.
        return new_tokens[:1] if len(new_tokens) > 0 else new_tokens

    if matched == len(old_words) and len(old_words) == len(fresh_words):
        # Both decodings are identical word sequences.
        return new_tokens

    target_text = sep.join(fresh_words[:matched])

    # Grow a token prefix until its decoding reaches the agreed word prefix.
    for cut in range(1, len(new_tokens) + 1):
        decoded = tokenizer.decode(new_tokens[:cut], skip_special_tokens=True)
        if decoded == target_text:
            return new_tokens[:cut]
        if len(decoded) > len(target_text):
            # Overshot: the previous cut was the last safe one (keep >= 1 token).
            return new_tokens[:max(1, cut - 1)]

    return new_tokens
|
||||
|
||||
|
||||
def manual_generate(
    encoder_outputs: torch.Tensor,
    attention_mask: torch.Tensor,
    forced_bos_token_id: int,
    max_length: int = 50,
    eos_token_id: Optional[int] = None
) -> Tuple[torch.Tensor, Optional[Tuple]]:
    """Full (non-incremental) generation from precomputed encoder outputs.

    Baseline to compare against continue_generation_with_cache(); relies on the
    module-level `model` and `tokenizer`.

    NOTE(review): `max_length` and `eos_token_id` are resolved but never
    forwarded to model.generate() — confirm whether that is intentional.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        # forced_bos_token_id pins the first decoder token to the target language.
        generated_tokens = model.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            forced_bos_token_id=forced_bos_token_id,
        )

    # Second element is reserved for a past_key_values-style cache; always None here.
    return generated_tokens, None
|
||||
|
||||
|
||||
def continue_generation_with_cache(
    encoder_hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    prefix_tokens: torch.Tensor,
    max_new_tokens: int = 50,
    eos_token_id: Optional[int] = None
) -> torch.Tensor:
    """Greedily extend a previously generated token prefix against new encoder states.

    Runs the decoder once over the whole prefix to build the KV cache, then
    continues one token at a time, reusing `past_key_values` so each step only
    processes the newest token. Uses the module-level `model`/`tokenizer`/`device`.

    NOTE(review): `attention_mask` is accepted but never passed to the decoder
    calls — confirm whether cross-attention masking is intentionally omitted.

    Returns:
        Tensor of shape (1, prefix_len + generated) with the extended sequence;
        returns the bare prefix if it already ends in EOS.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    prefix_tokens = prefix_tokens.to(device)

    # Decoder expects a batch dimension.
    if prefix_tokens.dim() == 1:
        prefix_tokens = prefix_tokens.unsqueeze(0)

    with torch.no_grad():
        # One pass over the whole prefix to prime the KV cache.
        decoder_out = model.model.decoder(
            input_ids=prefix_tokens,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=True,
            return_dict=True,
        )
        prefix_logits = model.lm_head(decoder_out.last_hidden_state)
        past_key_values = decoder_out.past_key_values

        # Greedy pick of the first continuation token.
        next_token_id = torch.argmax(prefix_logits[:, -1, :], dim=-1).unsqueeze(-1)

        if next_token_id.item() == eos_token_id:
            return prefix_tokens

        generated_tokens = torch.cat([prefix_tokens, next_token_id], dim=-1)

        # Budget counts the prefix: max_new_tokens bounds the total length.
        tokens_to_generate = max_new_tokens - generated_tokens.shape[1]

        for _ in range(tokens_to_generate):
            # Only the newest token is fed; history comes from past_key_values.
            decoder_out = model.model.decoder(
                input_ids=next_token_id,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            logits = model.lm_head(decoder_out.last_hidden_state)

            next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
            past_key_values = decoder_out.past_key_values

            if next_token_id.item() == eos_token_id:
                break

            generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)

    return generated_tokens
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Demo: translate a partial sentence, then extend the translation when more
    # source text arrives, reusing the earlier output as a decoder prefix.
    src_text_0 = "Have you noticed how accurate"
    src_text_1 = " LLM are now that GPU have became more powerful"

    inputs_0 = tokenizer(src_text_0, return_tensors="pt").to(device)
    encoder_outputs_0 = model.get_encoder()(**inputs_0)

    # Full generation for the first chunk.
    translation_tokens_0, _ = manual_generate(
        encoder_outputs=encoder_outputs_0,
        attention_mask=inputs_0['attention_mask'],
        forced_bos_token_id=bos_token_id,
    )

    print(f"translation 0: {tokenizer.decode(translation_tokens_0[0], skip_special_tokens=True)}")

    # Re-encode the full source, then continue decoding from the previous tokens.
    src_text_full = src_text_0 + src_text_1
    inputs_1 = tokenizer(src_text_full, return_tensors="pt").to(device)
    encoder_outputs_1 = model.get_encoder()(**inputs_1)

    translation_tokens_1 = continue_generation_with_cache(
        encoder_hidden_states=encoder_outputs_1.last_hidden_state,
        attention_mask=inputs_1['attention_mask'],
        prefix_tokens=translation_tokens_0[0].clone(),
        max_new_tokens=200
    )

    print(f"1: {tokenizer.decode(translation_tokens_1[0], skip_special_tokens=True)}")
|
||||
@@ -1,264 +0,0 @@
|
||||
# NLLB-200 language inventory. Each entry maps a display name to its NLLB code
# (FLORES-200 style `lll_Scri`) and the project's short language code
# (BCP-47-ish where one exists, otherwise the NLLB code itself).
LANGUAGES = [
    {"name": "Acehnese (Arabic script)", "nllb": "ace_Arab", "language_code": "ace_Arab"},
    {"name": "Acehnese (Latin script)", "nllb": "ace_Latn", "language_code": "ace_Latn"},
    {"name": "Mesopotamian Arabic", "nllb": "acm_Arab", "language_code": "acm_Arab"},
    {"name": "Ta'izzi-Adeni Arabic", "nllb": "acq_Arab", "language_code": "acq_Arab"},
    {"name": "Tunisian Arabic", "nllb": "aeb_Arab", "language_code": "aeb_Arab"},
    {"name": "Afrikaans", "nllb": "afr_Latn", "language_code": "af"},
    {"name": "South Levantine Arabic", "nllb": "ajp_Arab", "language_code": "ajp_Arab"},
    {"name": "Akan", "nllb": "aka_Latn", "language_code": "ak"},
    {"name": "Tosk Albanian", "nllb": "als_Latn", "language_code": "als"},
    {"name": "Amharic", "nllb": "amh_Ethi", "language_code": "am"},
    {"name": "North Levantine Arabic", "nllb": "apc_Arab", "language_code": "apc_Arab"},
    {"name": "Modern Standard Arabic", "nllb": "arb_Arab", "language_code": "ar"},
    {"name": "Modern Standard Arabic (Romanized)", "nllb": "arb_Latn", "language_code": "arb_Latn"},
    {"name": "Najdi Arabic", "nllb": "ars_Arab", "language_code": "ars_Arab"},
    {"name": "Moroccan Arabic", "nllb": "ary_Arab", "language_code": "ary_Arab"},
    {"name": "Egyptian Arabic", "nllb": "arz_Arab", "language_code": "arz_Arab"},
    {"name": "Assamese", "nllb": "asm_Beng", "language_code": "as"},
    {"name": "Asturian", "nllb": "ast_Latn", "language_code": "ast"},
    {"name": "Awadhi", "nllb": "awa_Deva", "language_code": "awa"},
    {"name": "Central Aymara", "nllb": "ayr_Latn", "language_code": "ay"},
    {"name": "South Azerbaijani", "nllb": "azb_Arab", "language_code": "azb"},
    {"name": "North Azerbaijani", "nllb": "azj_Latn", "language_code": "az"},
    {"name": "Bashkir", "nllb": "bak_Cyrl", "language_code": "ba"},
    {"name": "Bambara", "nllb": "bam_Latn", "language_code": "bm"},
    {"name": "Balinese", "nllb": "ban_Latn", "language_code": "ban"},
    {"name": "Belarusian", "nllb": "bel_Cyrl", "language_code": "be"},
    {"name": "Bemba", "nllb": "bem_Latn", "language_code": "bem"},
    {"name": "Bengali", "nllb": "ben_Beng", "language_code": "bn"},
    {"name": "Bhojpuri", "nllb": "bho_Deva", "language_code": "bho"},
    {"name": "Banjar (Arabic script)", "nllb": "bjn_Arab", "language_code": "bjn_Arab"},
    {"name": "Banjar (Latin script)", "nllb": "bjn_Latn", "language_code": "bjn_Latn"},
    {"name": "Standard Tibetan", "nllb": "bod_Tibt", "language_code": "bo"},
    {"name": "Bosnian", "nllb": "bos_Latn", "language_code": "bs"},
    {"name": "Buginese", "nllb": "bug_Latn", "language_code": "bug"},
    {"name": "Bulgarian", "nllb": "bul_Cyrl", "language_code": "bg"},
    {"name": "Catalan", "nllb": "cat_Latn", "language_code": "ca"},
    {"name": "Cebuano", "nllb": "ceb_Latn", "language_code": "ceb"},
    {"name": "Czech", "nllb": "ces_Latn", "language_code": "cs"},
    {"name": "Chokwe", "nllb": "cjk_Latn", "language_code": "cjk"},
    {"name": "Central Kurdish", "nllb": "ckb_Arab", "language_code": "ckb"},
    {"name": "Crimean Tatar", "nllb": "crh_Latn", "language_code": "crh"},
    {"name": "Welsh", "nllb": "cym_Latn", "language_code": "cy"},
    {"name": "Danish", "nllb": "dan_Latn", "language_code": "da"},
    {"name": "German", "nllb": "deu_Latn", "language_code": "de"},
    {"name": "Southwestern Dinka", "nllb": "dik_Latn", "language_code": "dik"},
    {"name": "Dyula", "nllb": "dyu_Latn", "language_code": "dyu"},
    {"name": "Dzongkha", "nllb": "dzo_Tibt", "language_code": "dz"},
    {"name": "Greek", "nllb": "ell_Grek", "language_code": "el"},
    {"name": "English", "nllb": "eng_Latn", "language_code": "en"},
    {"name": "Esperanto", "nllb": "epo_Latn", "language_code": "eo"},
    {"name": "Estonian", "nllb": "est_Latn", "language_code": "et"},
    {"name": "Basque", "nllb": "eus_Latn", "language_code": "eu"},
    {"name": "Ewe", "nllb": "ewe_Latn", "language_code": "ee"},
    {"name": "Faroese", "nllb": "fao_Latn", "language_code": "fo"},
    {"name": "Fijian", "nllb": "fij_Latn", "language_code": "fj"},
    {"name": "Finnish", "nllb": "fin_Latn", "language_code": "fi"},
    {"name": "Fon", "nllb": "fon_Latn", "language_code": "fon"},
    {"name": "French", "nllb": "fra_Latn", "language_code": "fr"},
    {"name": "Friulian", "nllb": "fur_Latn", "language_code": "fur-IT"},
    {"name": "Nigerian Fulfulde", "nllb": "fuv_Latn", "language_code": "fuv"},
    {"name": "West Central Oromo", "nllb": "gaz_Latn", "language_code": "om"},
    {"name": "Scottish Gaelic", "nllb": "gla_Latn", "language_code": "gd"},
    {"name": "Irish", "nllb": "gle_Latn", "language_code": "ga-IE"},
    {"name": "Galician", "nllb": "glg_Latn", "language_code": "gl"},
    {"name": "Guarani", "nllb": "grn_Latn", "language_code": "gn"},
    {"name": "Gujarati", "nllb": "guj_Gujr", "language_code": "gu-IN"},
    {"name": "Haitian Creole", "nllb": "hat_Latn", "language_code": "ht"},
    {"name": "Hausa", "nllb": "hau_Latn", "language_code": "ha"},
    {"name": "Hebrew", "nllb": "heb_Hebr", "language_code": "he"},
    {"name": "Hindi", "nllb": "hin_Deva", "language_code": "hi"},
    {"name": "Chhattisgarhi", "nllb": "hne_Deva", "language_code": "hne"},
    {"name": "Croatian", "nllb": "hrv_Latn", "language_code": "hr"},
    {"name": "Hungarian", "nllb": "hun_Latn", "language_code": "hu"},
    {"name": "Armenian", "nllb": "hye_Armn", "language_code": "hy-AM"},
    {"name": "Igbo", "nllb": "ibo_Latn", "language_code": "ig"},
    {"name": "Ilocano", "nllb": "ilo_Latn", "language_code": "ilo"},
    {"name": "Indonesian", "nllb": "ind_Latn", "language_code": "id"},
    {"name": "Icelandic", "nllb": "isl_Latn", "language_code": "is"},
    {"name": "Italian", "nllb": "ita_Latn", "language_code": "it"},
    {"name": "Javanese", "nllb": "jav_Latn", "language_code": "jv"},
    {"name": "Japanese", "nllb": "jpn_Jpan", "language_code": "ja"},
    {"name": "Kabyle", "nllb": "kab_Latn", "language_code": "kab"},
    {"name": "Jingpho", "nllb": "kac_Latn", "language_code": "kac"},
    {"name": "Kamba", "nllb": "kam_Latn", "language_code": "kam"},
    {"name": "Kannada", "nllb": "kan_Knda", "language_code": "kn"},
    {"name": "Kashmiri (Arabic script)", "nllb": "kas_Arab", "language_code": "kas_Arab"},
    {"name": "Kashmiri (Devanagari script)", "nllb": "kas_Deva", "language_code": "kas_Deva"},
    {"name": "Georgian", "nllb": "kat_Geor", "language_code": "ka"},
    {"name": "Kazakh", "nllb": "kaz_Cyrl", "language_code": "kk"},
    {"name": "Kabiyè", "nllb": "kbp_Latn", "language_code": "kbp"},
    {"name": "Kabuverdianu", "nllb": "kea_Latn", "language_code": "kea"},
    {"name": "Halh Mongolian", "nllb": "khk_Cyrl", "language_code": "mn"},
    {"name": "Khmer", "nllb": "khm_Khmr", "language_code": "km"},
    {"name": "Kikuyu", "nllb": "kik_Latn", "language_code": "ki"},
    {"name": "Kinyarwanda", "nllb": "kin_Latn", "language_code": "rw"},
    {"name": "Kyrgyz", "nllb": "kir_Cyrl", "language_code": "ky"},
    {"name": "Kimbundu", "nllb": "kmb_Latn", "language_code": "kmb"},
    {"name": "Northern Kurdish", "nllb": "kmr_Latn", "language_code": "kmr"},
    {"name": "Central Kanuri (Arabic script)", "nllb": "knc_Arab", "language_code": "knc_Arab"},
    {"name": "Central Kanuri (Latin script)", "nllb": "knc_Latn", "language_code": "knc_Latn"},
    {"name": "Kikongo", "nllb": "kon_Latn", "language_code": "kg"},
    {"name": "Korean", "nllb": "kor_Hang", "language_code": "ko"},
    {"name": "Lao", "nllb": "lao_Laoo", "language_code": "lo"},
    {"name": "Ligurian", "nllb": "lij_Latn", "language_code": "lij"},
    {"name": "Limburgish", "nllb": "lim_Latn", "language_code": "li"},
    {"name": "Lingala", "nllb": "lin_Latn", "language_code": "ln"},
    {"name": "Lithuanian", "nllb": "lit_Latn", "language_code": "lt"},
    {"name": "Lombard", "nllb": "lmo_Latn", "language_code": "lmo"},
    {"name": "Latgalian", "nllb": "ltg_Latn", "language_code": "ltg"},
    {"name": "Luxembourgish", "nllb": "ltz_Latn", "language_code": "lb"},
    {"name": "Luba-Kasai", "nllb": "lua_Latn", "language_code": "lua"},
    {"name": "Ganda", "nllb": "lug_Latn", "language_code": "lg"},
    {"name": "Luo", "nllb": "luo_Latn", "language_code": "luo"},
    {"name": "Mizo", "nllb": "lus_Latn", "language_code": "lus"},
    {"name": "Standard Latvian", "nllb": "lvs_Latn", "language_code": "lv"},
    {"name": "Magahi", "nllb": "mag_Deva", "language_code": "mag"},
    {"name": "Maithili", "nllb": "mai_Deva", "language_code": "mai"},
    {"name": "Malayalam", "nllb": "mal_Mlym", "language_code": "ml-IN"},
    {"name": "Marathi", "nllb": "mar_Deva", "language_code": "mr"},
    {"name": "Minangkabau (Arabic script)", "nllb": "min_Arab", "language_code": "min_Arab"},
    {"name": "Minangkabau (Latin script)", "nllb": "min_Latn", "language_code": "min_Latn"},
    {"name": "Macedonian", "nllb": "mkd_Cyrl", "language_code": "mk"},
    {"name": "Maltese", "nllb": "mlt_Latn", "language_code": "mt"},
    {"name": "Meitei (Bengali script)", "nllb": "mni_Beng", "language_code": "mni"},
    {"name": "Mossi", "nllb": "mos_Latn", "language_code": "mos"},
    {"name": "Maori", "nllb": "mri_Latn", "language_code": "mi"},
    {"name": "Burmese", "nllb": "mya_Mymr", "language_code": "my"},
    {"name": "Dutch", "nllb": "nld_Latn", "language_code": "nl"},
    {"name": "Norwegian Nynorsk", "nllb": "nno_Latn", "language_code": "nn-NO"},
    {"name": "Norwegian Bokmål", "nllb": "nob_Latn", "language_code": "nb"},
    {"name": "Nepali", "nllb": "npi_Deva", "language_code": "ne-NP"},
    {"name": "Northern Sotho", "nllb": "nso_Latn", "language_code": "nso"},
    {"name": "Nuer", "nllb": "nus_Latn", "language_code": "nus"},
    {"name": "Nyanja", "nllb": "nya_Latn", "language_code": "ny"},
    {"name": "Occitan", "nllb": "oci_Latn", "language_code": "oc"},
    {"name": "Odia", "nllb": "ory_Orya", "language_code": "or"},
    {"name": "Pangasinan", "nllb": "pag_Latn", "language_code": "pag"},
    {"name": "Eastern Panjabi", "nllb": "pan_Guru", "language_code": "pa"},
    {"name": "Papiamento", "nllb": "pap_Latn", "language_code": "pap"},
    {"name": "Southern Pashto", "nllb": "pbt_Arab", "language_code": "pbt"},
    {"name": "Western Persian", "nllb": "pes_Arab", "language_code": "fa"},
    {"name": "Plateau Malagasy", "nllb": "plt_Latn", "language_code": "mg"},
    {"name": "Polish", "nllb": "pol_Latn", "language_code": "pl"},
    {"name": "Portuguese", "nllb": "por_Latn", "language_code": "pt-PT"},
    {"name": "Dari", "nllb": "prs_Arab", "language_code": "fa-AF"},
    {"name": "Ayacucho Quechua", "nllb": "quy_Latn", "language_code": "qu"},
    {"name": "Romanian", "nllb": "ron_Latn", "language_code": "ro"},
    {"name": "Rundi", "nllb": "run_Latn", "language_code": "rn"},
    {"name": "Russian", "nllb": "rus_Cyrl", "language_code": "ru"},
    {"name": "Sango", "nllb": "sag_Latn", "language_code": "sg"},
    {"name": "Sanskrit", "nllb": "san_Deva", "language_code": "sa"},
    {"name": "Santali", "nllb": "sat_Olck", "language_code": "sat"},
    {"name": "Sicilian", "nllb": "scn_Latn", "language_code": "scn"},
    {"name": "Shan", "nllb": "shn_Mymr", "language_code": "shn"},
    {"name": "Sinhala", "nllb": "sin_Sinh", "language_code": "si-LK"},
    {"name": "Slovak", "nllb": "slk_Latn", "language_code": "sk"},
    {"name": "Slovenian", "nllb": "slv_Latn", "language_code": "sl"},
    {"name": "Samoan", "nllb": "smo_Latn", "language_code": "sm"},
    {"name": "Shona", "nllb": "sna_Latn", "language_code": "sn"},
    {"name": "Sindhi", "nllb": "snd_Arab", "language_code": "sd"},
    {"name": "Somali", "nllb": "som_Latn", "language_code": "so"},
    {"name": "Southern Sotho", "nllb": "sot_Latn", "language_code": "st"},
    {"name": "Spanish", "nllb": "spa_Latn", "language_code": "es-ES"},
    {"name": "Sardinian", "nllb": "srd_Latn", "language_code": "sc"},
    {"name": "Serbian", "nllb": "srp_Cyrl", "language_code": "sr"},
    {"name": "Swati", "nllb": "ssw_Latn", "language_code": "ss"},
    {"name": "Sundanese", "nllb": "sun_Latn", "language_code": "su"},
    {"name": "Swedish", "nllb": "swe_Latn", "language_code": "sv-SE"},
    {"name": "Swahili", "nllb": "swh_Latn", "language_code": "sw"},
    {"name": "Silesian", "nllb": "szl_Latn", "language_code": "szl"},
    {"name": "Tamil", "nllb": "tam_Taml", "language_code": "ta"},
    {"name": "Tamasheq (Latin script)", "nllb": "taq_Latn", "language_code": "taq_Latn"},
    {"name": "Tamasheq (Tifinagh script)", "nllb": "taq_Tfng", "language_code": "taq_Tfng"},
    {"name": "Tatar", "nllb": "tat_Cyrl", "language_code": "tt-RU"},
    {"name": "Telugu", "nllb": "tel_Telu", "language_code": "te"},
    {"name": "Tajik", "nllb": "tgk_Cyrl", "language_code": "tg"},
    {"name": "Tagalog", "nllb": "tgl_Latn", "language_code": "tl"},
    {"name": "Thai", "nllb": "tha_Thai", "language_code": "th"},
    {"name": "Tigrinya", "nllb": "tir_Ethi", "language_code": "ti"},
    {"name": "Tok Pisin", "nllb": "tpi_Latn", "language_code": "tpi"},
    {"name": "Tswana", "nllb": "tsn_Latn", "language_code": "tn"},
    {"name": "Tsonga", "nllb": "tso_Latn", "language_code": "ts"},
    {"name": "Turkmen", "nllb": "tuk_Latn", "language_code": "tk"},
    {"name": "Tumbuka", "nllb": "tum_Latn", "language_code": "tum"},
    {"name": "Turkish", "nllb": "tur_Latn", "language_code": "tr"},
    {"name": "Twi", "nllb": "twi_Latn", "language_code": "tw"},
    {"name": "Central Atlas Tamazight", "nllb": "tzm_Tfng", "language_code": "tzm"},
    {"name": "Uyghur", "nllb": "uig_Arab", "language_code": "ug"},
    {"name": "Ukrainian", "nllb": "ukr_Cyrl", "language_code": "uk"},
    {"name": "Umbundu", "nllb": "umb_Latn", "language_code": "umb"},
    {"name": "Urdu", "nllb": "urd_Arab", "language_code": "ur"},
    {"name": "Northern Uzbek", "nllb": "uzn_Latn", "language_code": "uz"},
    {"name": "Venetian", "nllb": "vec_Latn", "language_code": "vec"},
    {"name": "Vietnamese", "nllb": "vie_Latn", "language_code": "vi"},
    {"name": "Waray", "nllb": "war_Latn", "language_code": "war"},
    {"name": "Wolof", "nllb": "wol_Latn", "language_code": "wo"},
    {"name": "Xhosa", "nllb": "xho_Latn", "language_code": "xh"},
    {"name": "Eastern Yiddish", "nllb": "ydd_Hebr", "language_code": "yi"},
    {"name": "Yoruba", "nllb": "yor_Latn", "language_code": "yo"},
    {"name": "Yue Chinese", "nllb": "yue_Hant", "language_code": "yue"},
    {"name": "Chinese (Simplified)", "nllb": "zho_Hans", "language_code": "zh-CN"},
    {"name": "Chinese (Traditional)", "nllb": "zho_Hant", "language_code": "zh-TW"},
    {"name": "Standard Malay", "nllb": "zsm_Latn", "language_code": "ms"},
    {"name": "Zulu", "nllb": "zul_Latn", "language_code": "zu"},
]

# Derived lookup tables. If two entries ever shared a key, the later entry
# would silently win — names, NLLB codes and language codes are assumed unique.
NAME_TO_NLLB = {lang["name"]: lang["nllb"] for lang in LANGUAGES}
NAME_TO_LANGUAGE_CODE = {lang["name"]: lang["language_code"] for lang in LANGUAGES}
LANGUAGE_CODE_TO_NLLB = {lang["language_code"]: lang["nllb"] for lang in LANGUAGES}
NLLB_TO_LANGUAGE_CODE = {lang["nllb"]: lang["language_code"] for lang in LANGUAGES}
LANGUAGE_CODE_TO_NAME = {lang["language_code"]: lang["name"] for lang in LANGUAGES}
NLLB_TO_NAME = {lang["nllb"]: lang["name"] for lang in LANGUAGES}
|
||||
|
||||
|
||||
def get_nllb_code(language_code_code):
    """Return the NLLB code for a short language code, or None if unknown."""
    try:
        return LANGUAGE_CODE_TO_NLLB[language_code_code]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_language_code_code(nllb_code):
    """Return the short language code for an NLLB code, or None if unknown."""
    try:
        return NLLB_TO_LANGUAGE_CODE[nllb_code]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_language_name_by_language_code(language_code_code):
    """Return the display name for a short language code, or None if unknown."""
    try:
        return LANGUAGE_CODE_TO_NAME[language_code_code]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_language_name_by_nllb(nllb_code):
    """Return the display name for an NLLB code, or None if unknown."""
    try:
        return NLLB_TO_NAME[nllb_code]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_language_info(identifier, identifier_type="auto"):
    """Return the LANGUAGES entry matching `identifier`, or None.

    `identifier_type` selects the field to match: "name" (case-insensitive),
    "nllb", "language_code", or "auto" to try all three.
    """
    matchers = {
        "name": lambda lang: lang["name"].lower() == identifier.lower(),
        "nllb": lambda lang: lang["nllb"] == identifier,
        "language_code": lambda lang: lang["language_code"] == identifier,
    }

    if identifier_type == "auto":
        predicate = lambda lang: any(match(lang) for match in matchers.values())
    else:
        predicate = matchers.get(identifier_type)
        if predicate is None:
            # Unknown identifier_type behaves like "no match".
            return None

    for lang in LANGUAGES:
        if predicate(lang):
            return lang
    return None
|
||||
|
||||
|
||||
def list_all_languages():
    """All supported language display names, in definition order."""
    names = []
    for entry in LANGUAGES:
        names.append(entry["name"])
    return names
|
||||
|
||||
|
||||
def list_all_nllb_codes():
    """All supported NLLB codes, in definition order."""
    codes = []
    for entry in LANGUAGES:
        codes.append(entry["nllb"])
    return codes
|
||||
|
||||
|
||||
def list_all_language_code_codes():
    """All supported short language codes, in definition order."""
    codes = []
    for entry in LANGUAGES:
        codes.append(entry["language_code"])
    return codes
|
||||
@@ -1,171 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
from dataclasses import dataclass, field
|
||||
import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from whisperlivekit.timed_objects import Translation
|
||||
|
||||
logger = logging.getLogger(__name__)

# In the diarization case, we may want to translate just one speaker, or at
# least start the sentences there.

# After a silence of x seconds, we consider the model should not use the
# buffer, even if the previous sentence is not finished.
MIN_SILENCE_DURATION_DEL_BUFFER = 3
|
||||
|
||||
@dataclass
class TranslationModel():
    """Bundle of a loaded NLLB translator plus per-source-language tokenizers."""
    # ctranslate2.Translator, or a transformers Seq2Seq model when
    # backend_type == 'transformers' (annotation kept from original).
    translator: ctranslate2.Translator
    # "cuda" or "cpu"; tensors fed to the translator must live on this device.
    device: str
    # Lazily-filled cache: source-language NLLB code -> transformers tokenizer.
    tokenizer: dict = field(default_factory=dict)
    # Which inference backend `translator` belongs to.
    backend_type: str = 'ctranslate2'
    # Distilled NLLB checkpoint size (e.g. '600M'); selects the tokenizer repo.
    nllb_size: str = '600M'

    def get_tokenizer(self, input_lang):
        """Return (and cache) the NLLB tokenizer configured with `input_lang` as source."""
        if not self.tokenizer.get(input_lang, False):
            self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained(
                f"facebook/nllb-200-distilled-{self.nllb_size}",
                src_lang=input_lang,
                clean_up_tokenization_spaces=True
            )
        return self.tokenizer[input_lang]
|
||||
|
||||
|
||||
def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'):
    """Load an NLLB translation model plus one tokenizer per source language.

    Args:
        src_langs: iterable of source-language NLLB codes; 'auto' entries get
            no eager tokenizer (resolved later per token).
        nllb_backend: 'ctranslate2' or 'transformers'.
        nllb_size: distilled checkpoint size, e.g. '600M'.

    Returns:
        A TranslationModel wrapping the backend translator and its tokenizers.

    Raises:
        ValueError: if `nllb_backend` is not a supported backend. (Bug fix:
            previously an unknown backend fell through and crashed later with
            a confusing NameError on the unbound `translator`.)
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if nllb_backend == 'ctranslate2':
        model = f'nllb-200-distilled-{nllb_size}-ctranslate2'
        MODEL_GUY = 'entai2965'  # HF account hosting the ctranslate2 conversions
        huggingface_hub.snapshot_download(MODEL_GUY + '/' + model, local_dir=model)
        translator = ctranslate2.Translator(model, device=device)
    elif nllb_backend == 'transformers':
        model = f"facebook/nllb-200-distilled-{nllb_size}"
        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(model)
    else:
        raise ValueError(
            f"Unsupported nllb_backend: {nllb_backend!r} (expected 'ctranslate2' or 'transformers')"
        )

    # Eagerly build tokenizers for every known source language.
    tokenizer = dict()
    for src_lang in src_langs:
        if src_lang != 'auto':
            tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(model, src_lang=src_lang, clean_up_tokenization_spaces=True)

    translation_model = TranslationModel(
        translator=translator,
        tokenizer=tokenizer,
        backend_type=nllb_backend,
        device=device,
        nllb_size=nllb_size
    )
    for src_lang in src_langs:
        if src_lang != 'auto':
            # Warm the per-language tokenizer cache (no-op for languages loaded above).
            translation_model.get_tokenizer(src_lang)
    return translation_model
|
||||
|
||||
class OnlineTranslation:
    """Incremental (simultaneous) translation over a growing token stream.

    Tokens accumulate in `input_buffer`; when a punctuation token is seen the
    finished sentence is translated and moved to `validated`, while the
    unfinished tail is re-translated into `translation_remaining` on each
    process() call.
    """

    def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
        self.input_buffer = []         # tokens of the not-yet-finished sentence(s)
        self.len_processed_buffer = 0  # buffer length at the last process() call
        self.translation_remaining = Translation()  # translation of the unfinished tail
        self.validated = []            # Translation objects for completed sentences
        # Bug fix: compute_common_prefix() extends self.commited, but it was
        # never initialized, so that code path raised AttributeError.
        self.commited = []
        self.translation_pending_validation = ''
        self.translation_model = translation_model
        # NOTE: only the first entry of each language list is used for now.
        self.input_languages = input_languages
        self.output_languages = output_languages

    def compute_common_prefix(self, results):
        # We don't want to prune the result for the moment.
        if not self.input_buffer:
            self.input_buffer = results
        else:
            # NOTE(review): the loop keeps scanning after the first divergence,
            # re-slicing the buffer on every mismatch — confirm whether a
            # `break` after the first divergence is intended.
            for i in range(min(len(self.input_buffer), len(results))):
                if self.input_buffer[i] != results[i]:
                    self.commited.extend(self.input_buffer[:i])
                    self.input_buffer = results[i:]

    def translate(self, input, input_lang, output_lang):
        """Translate `input` text from `input_lang` to `output_lang` ("" for empty input)."""
        if not input:
            return ""
        nllb_output_lang = get_nllb_code(output_lang)

        tokenizer = self.translation_model.get_tokenizer(input_lang)
        tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)

        if self.translation_model.backend_type == 'ctranslate2':
            source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
            results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
            # Drop the forced target-language token from the hypothesis.
            target = results[0].hypotheses[0][1:]
            result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
        else:
            translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
            result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return result

    def translate_tokens(self, tokens):
        """Translate a token list into one Translation spanning its time range (None if empty)."""
        if tokens:
            text = ' '.join([token.text for token in tokens])
            start = tokens[0].start
            end = tokens[-1].end
            if self.input_languages[0] == 'auto':
                # Language detection comes from the ASR token itself.
                input_lang = tokens[0].detected_language
            else:
                input_lang = self.input_languages[0]

            translated_text = self.translate(text,
                input_lang,
                self.output_languages[0]
            )
            translation = Translation(
                text=translated_text,
                start=start,
                end=end,
            )
            return translation
        return None

    def insert_tokens(self, tokens):
        """Append freshly transcribed tokens to the pending buffer."""
        self.input_buffer.extend(tokens)

    def process(self):
        """Validate finished sentences and re-translate the unfinished tail.

        Returns:
            (validated, [translation_remaining]) — always a 2-tuple so callers
            can unpack it consistently.
        """
        i = 0
        if len(self.input_buffer) < self.len_processed_buffer + 3:  # nothing new to process
            # Bug fix: this early exit used to return a flat list
            # (self.validated + [self.translation_remaining]) while the normal
            # path returns a tuple, breaking callers that unpack two values.
            return self.validated, [self.translation_remaining]
        while i < len(self.input_buffer):
            if self.input_buffer[i].is_punctuation():
                # A sentence ended: translate it as a whole and validate it.
                translation_sentence = self.translate_tokens(self.input_buffer[:i+1])
                self.validated.append(translation_sentence)
                self.input_buffer = self.input_buffer[i+1:]
                i = 0
            else:
                i += 1
        self.translation_remaining = self.translate_tokens(self.input_buffer)
        self.len_processed_buffer = len(self.input_buffer)
        return self.validated, [self.translation_remaining]

    def insert_silence(self, silence_duration: float):
        """After a long silence, drop the buffer and promote the pending translation."""
        if silence_duration >= MIN_SILENCE_DURATION_DEL_BUFFER:
            self.input_buffer = []
            self.validated += [self.translation_remaining]
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: stream an English sentence through the transformers backend
    # in word chunks and print each partial translation.
    output_lang = 'fr'
    input_lang = "en"

    test_string = """
Transcription technology has improved so much in the past few years. Have you noticed how accurate real-time speech-to-text is now?
"""
    test = test_string.split(' ')
    step = len(test) // 3

    shared_model = load_model([input_lang], nllb_backend='transformers')
    online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])

    beg_inference = time.time()
    # 5 iterations with step = len//3: chunks past the end are empty and
    # translate() returns "" for them.
    for id in range(5):
        val = test[id*step : (id+1)*step]
        val_str = ' '.join(val)
        result = online_translation.translate(val_str, input_lang = input_lang, output_lang = output_lang)
        print(result)
    print('inference time:', time.time() - beg_inference)
|
||||
Reference in New Issue
Block a user