From 2963e8a757ed9ce9f5393896f9a2aa86431777b5 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Tue, 9 Sep 2025 21:45:00 +0200 Subject: [PATCH] translate when at least 3 new tokens --- README.md | 2 +- whisperlivekit/timed_objects.py | 4 ++-- whisperlivekit/translation/translation.py | 11 ++++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index aca95a4..8dffe79 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ The rest I don't recommend. But below are your options. | `--model` | Whisper model size. | `small` | | `--language` | Source language code or `auto` | `auto` | | `--task` | Set to `translate` to translate to english | `transcribe` | -| `--target-language` | [NOT FUNCTIONAL YET] | `None` | +| `--target-language` | [BETA] Translation language target. Ex: `fr` | `None` | | `--backend` | Processing backend | `simulstreaming` | | `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` | | `--no-vac` | Disable Voice Activity Controller | `False` | diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index ab4045a..09545b5 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -9,8 +9,8 @@ def format_time(seconds: float) -> str: @dataclass class TimedText: - start: Optional[float] - end: Optional[float] + start: Optional[float] = 0 + end: Optional[float] = 0 text: Optional[str] = '' speaker: Optional[int] = -1 probability: Optional[float] = None diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index 7a4c734..822e0ca 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -41,6 +41,8 @@ def translate(input, translation_model, tgt_lang): class OnlineTranslation: def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): self.buffer = [] + self.len_processed_buffer = 0 + self.translation_remaining = Translation() self.validated = [] self.translation_pending_validation = '' self.translation_model = translation_model @@ -48,6 +50,7 @@ class OnlineTranslation: self.output_languages = output_languages def compute_common_prefix(self, results): + #we dont want want to prune the result for the moment. if not self.buffer: self.buffer = results else: @@ -63,7 +66,6 @@ class OnlineTranslation: input_lang = self.input_languages[0] if output_lang is None: output_lang = self.output_languages[0] - nllb_input_lang = get_nllb_code(input_lang) nllb_output_lang = get_nllb_code(output_lang) source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input)) @@ -94,6 +96,8 @@ class OnlineTranslation: def process(self): i = 0 + if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process + return self.validated + [self.translation_remaining] while i < len(self.buffer): if self.buffer[i].text in PUNCTUATION_MARKS: translation_sentence = self.translate_tokens(self.buffer[:i+1]) @@ -102,8 +106,9 @@ class OnlineTranslation: i = 0 else: i+=1 - translation_remaining = self.translate_tokens(self.buffer) - return self.validated + [translation_remaining] + self.translation_remaining = self.translate_tokens(self.buffer) + self.len_processed_buffer = len(self.buffer) + return self.validated + [self.translation_remaining] if __name__ == '__main__':