Translate only when at least 3 new tokens are available

This commit is contained in:
Quentin Fuxa
2025-09-09 21:45:00 +02:00
parent cb2d4ea88a
commit 2963e8a757
3 changed files with 11 additions and 6 deletions

View File

@@ -151,7 +151,7 @@ The rest I don't recommend. But below are your options.
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | Set to `translate` to translate to English | `transcribe` |
| `--target-language` | [NOT FUNCTIONAL YET] | `None` |
| `--target-language` | [BETA] Target language for translation, e.g. `fr` | `None` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--no-vac` | Disable Voice Activity Controller | `False` |

View File

@@ -9,8 +9,8 @@ def format_time(seconds: float) -> str:
@dataclass
class TimedText:
start: Optional[float]
end: Optional[float]
start: Optional[float] = 0
end: Optional[float] = 0
text: Optional[str] = ''
speaker: Optional[int] = -1
probability: Optional[float] = None

View File

@@ -41,6 +41,8 @@ def translate(input, translation_model, tgt_lang):
class OnlineTranslation:
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
self.buffer = []
self.len_processed_buffer = 0
self.translation_remaining = Translation()
self.validated = []
self.translation_pending_validation = ''
self.translation_model = translation_model
@@ -48,6 +50,7 @@ class OnlineTranslation:
self.output_languages = output_languages
def compute_common_prefix(self, results):
# We don't want to prune the result for the moment.
if not self.buffer:
self.buffer = results
else:
@@ -63,7 +66,6 @@ class OnlineTranslation:
input_lang = self.input_languages[0]
if output_lang is None:
output_lang = self.output_languages[0]
nllb_input_lang = get_nllb_code(input_lang)
nllb_output_lang = get_nllb_code(output_lang)
source = self.translation_model.tokenizer[input_lang].convert_ids_to_tokens(self.translation_model.tokenizer[input_lang].encode(input))
@@ -94,6 +96,8 @@ class OnlineTranslation:
def process(self):
i = 0
if len(self.buffer) < self.len_processed_buffer + 3: #nothing new to process
return self.validated + [self.translation_remaining]
while i < len(self.buffer):
if self.buffer[i].text in PUNCTUATION_MARKS:
translation_sentence = self.translate_tokens(self.buffer[:i+1])
@@ -102,8 +106,9 @@ class OnlineTranslation:
i = 0
else:
i+=1
translation_remaining = self.translate_tokens(self.buffer)
return self.validated + [translation_remaining]
self.translation_remaining = self.translate_tokens(self.buffer)
self.len_processed_buffer = len(self.buffer)
return self.validated + [self.translation_remaining]
if __name__ == '__main__':