mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
nllb backend can be transformers, and model size can be 1.3B
This commit is contained in:
@@ -200,7 +200,8 @@ An important list of parameters can be changed. But what *should* you change?
|
||||
|
||||
| Translation options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nllb-backend` | [NOT FUNCTIONNAL YET] transformer or ctranslate2 | `ctranslate2` |
|
||||
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
|
||||
| `--nllb-size` | `600M` or `1.3B` | `600M` |
|
||||
|
||||
> For diarization using Diart, you need access to pyannote.audio models:
|
||||
> 1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
|
||||
|
||||
@@ -70,7 +70,8 @@ class TranscriptionEngine:
|
||||
"embedding_model": "pyannote/embedding",
|
||||
|
||||
# translation params:
|
||||
"nllb_backend": "ctranslate2"
|
||||
"nllb_backend": "ctranslate2",
|
||||
"nllb_size": "600M"
|
||||
}
|
||||
|
||||
config_dict = {**defaults, **kwargs}
|
||||
@@ -148,8 +149,7 @@ class TranscriptionEngine:
|
||||
raise Exception('Translation cannot be set with language auto')
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend) #in the future we want to handle different languages for different speakers
|
||||
|
||||
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
|
||||
|
||||
@@ -291,7 +291,14 @@ def parse_args():
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="ctranslate2",
|
||||
help="transformer or ctranslate2",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-size",
|
||||
type=str,
|
||||
default="600M",
|
||||
help="600M or 1.3B",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
@@ -20,39 +21,29 @@ MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider t
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
tokenizer: dict
|
||||
device: str
|
||||
backend_type: str = 'ctranslate2'
|
||||
|
||||
def load_model(src_langs, backend='ctranslate2'):
|
||||
def load_model(src_langs, backend='ctranslate2', model_size='600M'):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
MODEL = f'nllb-200-distilled-{model_size}-ctranslate2'
|
||||
if backend=='ctranslate2':
|
||||
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
||||
MODEL_GUY = 'entai2965'
|
||||
huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
translator = ctranslate2.Translator(MODEL,device=device)
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
elif backend=='transformers':
|
||||
raise Exception('not implemented yet')
|
||||
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}")
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
|
||||
return TranslationModel(
|
||||
translator=translator,
|
||||
tokenizer=tokenizer
|
||||
tokenizer=tokenizer,
|
||||
backend_type=backend,
|
||||
device = device
|
||||
)
|
||||
|
||||
def translate(input, translation_model, tgt_lang, src_lang="en"):
|
||||
# Get the specific tokenizer for the source language
|
||||
tokenizer = translation_model.tokenizer[src_lang]
|
||||
|
||||
# Convert input to tokens
|
||||
source = tokenizer.convert_ids_to_tokens(tokenizer.encode(input))
|
||||
|
||||
# Translate with target language prefix
|
||||
target_prefix = [tgt_lang]
|
||||
results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
|
||||
|
||||
# Get translated tokens and decode
|
||||
target = results[0].hypotheses[0][1:]
|
||||
return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
@@ -83,12 +74,19 @@ class OnlineTranslation:
|
||||
output_lang = self.output_languages[0]
|
||||
nllb_output_lang = get_nllb_code(output_lang)
|
||||
|
||||
source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) #we can use return_attention=True to try to optimize the stuff.
|
||||
target = results[0].hypotheses[0][1:]
|
||||
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
|
||||
return results
|
||||
|
||||
tokenizer = self.translation_model.tokenizer[input_lang]
|
||||
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
|
||||
|
||||
if self.translation_model.backend_type == 'ctranslate2':
|
||||
source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
|
||||
results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
result = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
|
||||
else:
|
||||
translated_tokens = self.translation_model.translator.generate(**tokenizer_output, forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang))
|
||||
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
return result
|
||||
|
||||
def translate_tokens(self, tokens):
|
||||
if tokens:
|
||||
text = ' '.join([token.text for token in tokens])
|
||||
@@ -103,7 +101,6 @@ class OnlineTranslation:
|
||||
return translation
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
|
||||
self.buffer.extend(tokens)
|
||||
@@ -141,16 +138,13 @@ if __name__ == '__main__':
|
||||
test = test_string.split(' ')
|
||||
step = len(test) // 3
|
||||
|
||||
shared_model = load_model([input_lang])
|
||||
shared_model = load_model([input_lang], backend='ctranslate2')
|
||||
online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
|
||||
|
||||
|
||||
beg_inference = time.time()
|
||||
for id in range(5):
|
||||
val = test[id*step : (id+1)*step]
|
||||
val_str = ' '.join(val)
|
||||
result = online_translation.translate(val_str)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
|
||||
# print(result)
|
||||
print('inference time:', time.time() - beg_inference)
|
||||
Reference in New Issue
Block a user