diff --git a/.gitignore b/.gitignore index 7c566bf..6bb0fc4 100644 --- a/.gitignore +++ b/.gitignore @@ -137,4 +137,5 @@ run_*.sh test_*.py launch.json .DS_Store -test/* \ No newline at end of file +test/* +nllb-200-distilled-600M-ctranslate2/* \ No newline at end of file diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 01d79d5..98d2557 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -5,7 +5,7 @@ import math import logging import traceback from whisperlivekit.timed_objects import ASRToken, Silence -from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory +from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.results_formater import format_output, format_time @@ -48,6 +48,7 @@ class AudioProcessor: self.silence = False self.silence_duration = 0.0 self.tokens = [] + self.translated_tokens = [] self.buffer_transcription = "" self.buffer_diarization = "" self.end_buffer = 0 @@ -80,23 +81,21 @@ class AudioProcessor: self.transcription_queue = asyncio.Queue() if self.args.transcription else None self.diarization_queue = asyncio.Queue() if self.args.diarization else None + self.translation_queue = asyncio.Queue() if self.args.target_language else None self.pcm_buffer = bytearray() - # Task references self.transcription_task = None self.diarization_task = None self.ffmpeg_reader_task = None self.watchdog_task = None self.all_tasks_for_cleanup = [] - # Initialize transcription engine if enabled if self.args.transcription: - self.online = online_factory(self.args, models.asr, models.tokenizer) - - # Initialize diarization engine if enabled + self.online = online_factory(self.args, models.asr, models.tokenizer) if self.args.diarization: 
self.diarization = online_diarization_factory(self.args, models.diarization_model) - + if self.args.target_language: + self.online_translation = online_translation_factory(self.args, models.translation_model) def convert_pcm_to_float(self, pcm_buffer): """Convert PCM buffer in s16le format to normalized NumPy array.""" @@ -143,6 +142,7 @@ class AudioProcessor: return { "tokens": self.tokens.copy(), + "translated_tokens": self.translated_tokens.copy(), "buffer_transcription": self.buffer_transcription, "buffer_diarization": self.buffer_diarization, "end_buffer": self.end_buffer, @@ -156,6 +156,7 @@ class AudioProcessor: """Reset all state variables to initial values.""" async with self.lock: self.tokens = [] + self.translated_tokens = [] self.buffer_transcription = self.buffer_diarization = "" self.end_buffer = self.end_attributed_speaker = 0 self.beg_loop = time() @@ -391,6 +392,15 @@ class AudioProcessor: self.diarization_queue.task_done() logger.info("Diarization processor task finished.") + async def translation_processor(self, online_translation): + # the idea is to ignore diarization for the moment. We use only transcription tokens. + # And the speaker is attributed given the segments used for the translation + # in the future we want to have different languages for each speaker etc, so it will be more complex. 
+ while True: + try: + item = await self.translation_queue.get() + except Exception as e: + logger.warning(f"Exception in translation_processor: {e}") async def results_formatter(self): """Format processing results for output.""" @@ -536,6 +546,9 @@ class AudioProcessor: self.all_tasks_for_cleanup.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task) + if self.args.target_language and self.args.language != 'auto': + self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation)) + self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader()) self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task) processing_tasks_for_watchdog.append(self.ffmpeg_reader_task) diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index bc07aa8..b3bbbe7 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -140,7 +140,7 @@ class TranscriptionEngine: raise Exception('Translation cannot be set with language auto') else: from whisperlivekit.translation.translation import load_model - self.translation_model = load_model() + self.translation_model = load_model([self.args.language]) #in the future we want to handle different languages for different speakers TranscriptionEngine._initialized = True @@ -168,11 +168,17 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr): def online_diarization_factory(args, diarization_backend): if args.diarization_backend == "diart": online = diarization_backend - # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommanded + # Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended if args.diarization_backend == "sortformer": from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline online = SortformerDiarizationOnline(shared_model=diarization_backend) return 
online - \ No newline at end of file + +def online_translation_factory(args, translation_model): +    #should be at speaker level in the future: +    #one shared nllb model for all speaker +    #one tokenizer per speaker/language +    from whisperlivekit.translation.translation import OnlineTranslation +    return OnlineTranslation(translation_model, [args.language], [args.target_language]) \ No newline at end of file diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index edea57b..edc6677 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -3,40 +3,78 @@ import torch import transformers from dataclasses import dataclass import huggingface_hub -from .mapping_languages import get_nllb_code +from whisperlivekit.translation.mapping_languages import get_nllb_code + +#In diarization case, we may want to translate just one speaker, or at least start the sentences there + + @dataclass class TranslationModel(): translator: ctranslate2.Translator - tokenizer: transformers.AutoTokenizer + tokenizer: dict -def load_model(src_lang): +def load_model(src_langs): MODEL = 'nllb-200-distilled-600M-ctranslate2' MODEL_GUY = 'entai2965' huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = "cuda" if torch.cuda.is_available() else "cpu" translator = ctranslate2.Translator(MODEL,device=device) - tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True) + tokenizer = dict() + for src_lang in src_langs: + tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=get_nllb_code(src_lang), clean_up_tokenization_spaces=True) return TranslationModel( translator=translator, tokenizer=tokenizer ) def translate(input, translation_model, tgt_lang): - if not input: - return "" source = 
translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input)) target_prefix = [tgt_lang] results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix]) target = results[0].hypotheses[0][1:] return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target)) +class OnlineTranslation: +    def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): +        self.buffer = [] +        self.commited = [] +        self.translation_model = translation_model +        self.input_languages = input_languages +        self.output_languages = output_languages + +    def compute_common_prefix(self, results): +        if not self.buffer: +            self.buffer = results +        else: +            for i in range(min(len(self.buffer), len(results))): +                if self.buffer[i] != results[i]: +                    self.commited.extend(self.buffer[:i]) +                    self.buffer = results[i:]; break + +    def translate(self, input, input_lang=None, output_lang=None): +        if not input: +            return "" +        if input_lang is None: +            input_lang = self.input_languages[0] +        if output_lang is None: +            output_lang = self.output_languages[0] +        nllb_input_lang = get_nllb_code(input_lang) +        nllb_output_lang = get_nllb_code(output_lang) + +        source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input)) +        results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]]) +        target = results[0].hypotheses[0][1:] +        results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target)) +        return results + if __name__ == '__main__': - tgt_lang = 'fr' - src_lang = "en" - nllb_tgt_lang = get_nllb_code(tgt_lang) - nllb_src_lang = get_nllb_code(src_lang) - translation_model = load_model(nllb_src_lang) - result = translate('Hello world', translation_model=translation_model, tgt_lang=nllb_tgt_lang) + output_lang = 'fr' + 
input_lang = "en" + + shared_model = load_model([input_lang]) + online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang]) + + result = online_translation.translate('Hello world') print(result) \ No newline at end of file diff --git a/whisperlivekit/web/live_transcription.css b/whisperlivekit/web/live_transcription.css index 118b4f7..422d156 100644 --- a/whisperlivekit/web/live_transcription.css +++ b/whisperlivekit/web/live_transcription.css @@ -368,6 +368,27 @@ label { color: var(--label-trans-text); } +.label_translation { + background-color: var(--chip-bg); + border-radius: 10px; + padding: 4px 8px; + margin-top: 4px; + font-size: 14px; + color: var(--text); + display: flex; + align-items: flex-start; + gap: 4px; +} + +.label_translation img { + margin-top: 2px; +} + +.label_translation img { + width: 12px; + height: 12px; +} + #timeInfo { color: var(--muted); margin-left: 10px; @@ -417,6 +438,7 @@ label { font-size: 13px; border-radius: 30px; padding: 2px 10px; + display: none; } .loading { diff --git a/whisperlivekit/web/live_transcription.js b/whisperlivekit/web/live_transcription.js index 5d57e51..af804b5 100644 --- a/whisperlivekit/web/live_transcription.js +++ b/whisperlivekit/web/live_transcription.js @@ -332,6 +332,13 @@ function renderLinesWithBuffer( } let currentLineText = item.text || ""; + + if (item.translation) { + currentLineText += `