diff --git a/DEV_NOTES.md b/DEV_NOTES.md index c41016f..f9c3c4a 100644 --- a/DEV_NOTES.md +++ b/DEV_NOTES.md @@ -18,8 +18,29 @@ Decoder weights: 59110771 bytes Encoder weights: 15268874 bytes +# 2. Translation: Faster model for each system -# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm +## Benchmark Results + +Testing on MacBook M3 with NLLB-200-distilled-600M model: + +### Standard Transformers vs CTranslate2 + +| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup | +|-----------|-------------------------|---------------------------|---------| +| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x | +| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x | +| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x | +| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x | +| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x | + +**Results:** +- Total Standard time: 4.1068s +- Total CTranslate2 time: 8.5476s +- CTranslate2 is slower on this system --> Use Transformers, and ideally we would have an mlx implementation. + + +# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions. @@ -67,4 +88,4 @@ ELSE: AS_2 ← B to finish -``` \ No newline at end of file +``` diff --git a/README.md b/README.md index 2d0cb83..2c267a7 100644 --- a/README.md +++ b/README.md @@ -198,6 +198,10 @@ An important list of parameters can be changed. But what *should* you change? | `--embedding-model` | Hugging Face model ID for Diart embedding model. 
[Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` | +| Translation options | Description | Default | +|-----------|-------------|---------| +| `--nllb-backend` | [NOT FUNCTIONAL YET] transformers or ctranslate2 | `ctranslate2` | + > For diarization using Diart, you need access to pyannote.audio models: > 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model > 2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index fd290d5..7f2eaf4 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -43,10 +43,12 @@ class TranscriptionEngine: "transcription": True, "vad": True, "pcm_input": False, + # whisperstreaming params: "buffer_trimming": "segment", "confidence_validation": False, "buffer_trimming_sec": 15, + # simulstreaming params: "disable_fast_encoder": False, "frame_threshold": 25, @@ -61,10 +63,14 @@ class TranscriptionEngine: "max_context_tokens": None, "model_path": './base.pt', "diarization_backend": "sortformer", + # diarization params: "disable_punctuation_split" : False, "segmentation_model": "pyannote/segmentation-3.0", - "embedding_model": "pyannote/embedding", + "embedding_model": "pyannote/embedding", + + # translation params: + "nllb_backend": "ctranslate2" } config_dict = {**defaults, **kwargs} @@ -142,7 +148,7 @@ class TranscriptionEngine: raise Exception('Translation cannot be set with language auto') else: from whisperlivekit.translation.translation import load_model - self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers + self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend) #in the future we want to handle different languages for different speakers 
TranscriptionEngine._initialized = True diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 3ef74bf..c73e6d4 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -287,6 +287,13 @@ def parse_args(): help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).", ) + simulstreaming_group.add_argument( + "--nllb-backend", + type=str, + default="ctranslate2", + help="transformers or ctranslate2", + ) + args = parser.parse_args() args.transcription = not args.no_transcription diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index 88bb5e2..7923243 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -21,26 +21,37 @@ class TranslationModel(): translator: ctranslate2.Translator tokenizer: dict -def load_model(src_langs): - MODEL = 'nllb-200-distilled-600M-ctranslate2' - MODEL_GUY = 'entai2965' - huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL) - device = "cuda" if torch.cuda.is_available() else "cpu" - translator = ctranslate2.Translator(MODEL,device=device) - tokenizer = dict() - for src_lang in src_langs: - tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True) +def load_model(src_langs, backend='ctranslate2'): + if backend=='ctranslate2': + MODEL = 'nllb-200-distilled-600M-ctranslate2' + MODEL_GUY = 'entai2965' + huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL) + device = "cuda" if torch.cuda.is_available() else "cpu" + translator = ctranslate2.Translator(MODEL,device=device) + tokenizer = dict() + for src_lang in src_langs: + tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True) + elif backend=='transformers': + raise Exception('not implemented yet') return 
TranslationModel( translator=translator, tokenizer=tokenizer ) -def translate(input, translation_model, tgt_lang): - source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input)) +def translate(input, translation_model, tgt_lang, src_lang="en"): + # Get the specific tokenizer for the source language + tokenizer = translation_model.tokenizer[src_lang] + + # Convert input to tokens + source = tokenizer.convert_ids_to_tokens(tokenizer.encode(input)) + + # Translate with target language prefix target_prefix = [tgt_lang] results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix]) + + # Get translated tokens and decode target = results[0].hypotheses[0][1:] - return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target)) + return tokenizer.decode(tokenizer.convert_tokens_to_ids(target)) class OnlineTranslation: def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): @@ -142,4 +153,4 @@ if __name__ == '__main__': - # print(result) \ No newline at end of file + # print(result)