translation asyncio task

This commit is contained in:
Quentin Fuxa
2025-09-07 16:30:00 +02:00
parent b6164aa59b
commit f661f21675
6 changed files with 111 additions and 24 deletions

3
.gitignore vendored
View File

@@ -137,4 +137,5 @@ run_*.sh
test_*.py
launch.json
.DS_Store
test/*
test/*
nllb-200-distilled-600M-ctranslate2/*

View File

@@ -5,7 +5,7 @@ import math
import logging
import traceback
from whisperlivekit.timed_objects import ASRToken, Silence
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output, format_time
@@ -48,6 +48,7 @@ class AudioProcessor:
self.silence = False
self.silence_duration = 0.0
self.tokens = []
self.translated_tokens = []
self.buffer_transcription = ""
self.buffer_diarization = ""
self.end_buffer = 0
@@ -80,23 +81,21 @@ class AudioProcessor:
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
self.translation_queue = asyncio.Queue() if self.args.target_language else None
self.pcm_buffer = bytearray()
# Task references
self.transcription_task = None
self.diarization_task = None
self.ffmpeg_reader_task = None
self.watchdog_task = None
self.all_tasks_for_cleanup = []
# Initialize transcription engine if enabled
if self.args.transcription:
self.online = online_factory(self.args, models.asr, models.tokenizer)
# Initialize diarization engine if enabled
self.online = online_factory(self.args, models.asr, models.tokenizer)
if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model)
if self.args.target_language:
self.online_translation = online_translation_factory(self.args, models.translation_model)
def convert_pcm_to_float(self, pcm_buffer):
"""Convert PCM buffer in s16le format to normalized NumPy array."""
@@ -143,6 +142,7 @@ class AudioProcessor:
return {
"tokens": self.tokens.copy(),
"translated_tokens": self.translated_tokens.copy(),
"buffer_transcription": self.buffer_transcription,
"buffer_diarization": self.buffer_diarization,
"end_buffer": self.end_buffer,
@@ -156,6 +156,7 @@ class AudioProcessor:
"""Reset all state variables to initial values."""
async with self.lock:
self.tokens = []
self.translated_tokens = []
self.buffer_transcription = self.buffer_diarization = ""
self.end_buffer = self.end_attributed_speaker = 0
self.beg_loop = time()
@@ -391,6 +392,15 @@ class AudioProcessor:
self.diarization_queue.task_done()
logger.info("Diarization processor task finished.")
    async def translation_processor(self, online_translation):
        """Consume transcription tokens from ``self.translation_queue`` and translate them.

        NOTE(review): work-in-progress — the dequeued ``item`` and the
        ``online_translation`` engine are not used yet, and ``task_done()`` is
        never called on the queue (unlike the diarization processor); confirm
        before relying on ``translation_queue.join()``.
        """
        # The idea is to ignore diarization for the moment and use only
        # transcription tokens. The speaker is attributed based on the segments
        # used for the translation. In the future we want to support a different
        # language per speaker etc., so this will become more complex.
        while True:
            try:
                item = await self.translation_queue.get()
            except Exception as e:
                # Best-effort: log and keep the loop alive rather than killing the task.
                logger.warning(f"Exception in translation_processor: {e}")
async def results_formatter(self):
"""Format processing results for output."""
@@ -536,6 +546,9 @@ class AudioProcessor:
self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task)
if self.args.target_language and self.args.language != 'auto':
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)

View File

@@ -140,7 +140,7 @@ class TranscriptionEngine:
raise Exception('Translation cannot be set with language auto')
else:
from whisperlivekit.translation.translation import load_model
self.translation_model = load_model()
self.translation_model = load_model([self.args.language]) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True
@@ -168,11 +168,17 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
def online_diarization_factory(args, diarization_backend):
    """Build the per-connection online diarization object for the configured backend.

    Args:
        args: parsed CLI/config namespace; ``args.diarization_backend`` selects
            either ``"diart"`` or ``"sortformer"``.
        diarization_backend: the shared, already-loaded diarization model/backend.

    Returns:
        The online diarization object for this connection.

    Raises:
        ValueError: if ``args.diarization_backend`` is not a known backend
            (previously this fell through to an UnboundLocalError).
    """
    online = None
    if args.diarization_backend == "diart":
        # Not ideal: several users/instances share the same diart backend, but
        # diart is no longer SOTA and sortformer is the recommended backend.
        online = diarization_backend
    if args.diarization_backend == "sortformer":
        from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
        online = SortformerDiarizationOnline(shared_model=diarization_backend)
    if online is None:
        raise ValueError(f"Unknown diarization backend: {args.diarization_backend}")
    return online
def online_translation_factory(args, translation_model):
    """Build the per-connection online translation object.

    Args:
        args: parsed CLI/config namespace; ``args.language`` is the source
            language and ``args.target_language`` the target language.
        translation_model: the shared, already-loaded translation model.

    Returns:
        An ``OnlineTranslation`` instance wired to the shared model.
    """
    # Should be at speaker level in the future:
    #   - one shared NLLB model for all speakers
    #   - one tokenizer per speaker/language
    from whisperlivekit.translation.translation import OnlineTranslation
    online = OnlineTranslation(translation_model, [args.language], [args.target_language])
    # Bug fix: the factory previously never returned, so callers received None.
    return online

View File

@@ -3,40 +3,78 @@ import torch
import transformers
from dataclasses import dataclass
import huggingface_hub
from .mapping_languages import get_nllb_code
from whisperlivekit.translation.mapping_languages import get_nllb_code
# In the diarization case, we may want to translate just one speaker, or at least start sentences at speaker boundaries.
@dataclass
class TranslationModel():
    """Bundles one shared ctranslate2 translator with one tokenizer per source language."""
    translator: ctranslate2.Translator
    # Maps source-language code -> transformers.AutoTokenizer configured for that language.
    # (Was annotated `dict()`, which *calls* dict instead of naming the type.)
    tokenizer: dict


def load_model(src_langs):
    """Download (if missing) and load the NLLB ctranslate2 model.

    Args:
        src_langs: list of source-language codes; one tokenizer is created per
            entry, keyed on the code as given.

    Returns:
        A ``TranslationModel`` with the shared translator and the tokenizer dict.

    NOTE(review): ``src_lang`` is forwarded verbatim to ``AutoTokenizer``; NLLB
    tokenizers expect codes such as ``eng_Latn``, while callers pass raw codes
    like ``en`` — confirm which form is intended.
    """
    MODEL = 'nllb-200-distilled-600M-ctranslate2'
    MODEL_GUY = 'entai2965'
    huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL, local_dir=MODEL)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    translator = ctranslate2.Translator(MODEL, device=device)
    tokenizer = {}
    for src_lang in src_langs:
        tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(
            MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
    return TranslationModel(
        translator=translator,
        tokenizer=tokenizer
    )
def translate(input, translation_model, tgt_lang):
    """Translate *input* text with the given model, targeting *tgt_lang*.

    NOTE(review): legacy single-tokenizer API — this expects
    ``translation_model.tokenizer`` to be one AutoTokenizer, not the
    per-language dict built by the dict-based ``load_model``; confirm callers.

    Returns "" for empty/falsy input, otherwise the decoded translation.
    """
    if not input:
        return ""
    # Tokenize the source text into the token strings ctranslate2 expects.
    source = translation_model.tokenizer.convert_ids_to_tokens(translation_model.tokenizer.encode(input))
    # NLLB uses the target-language token as the decoder prefix.
    target_prefix = [tgt_lang]
    results = translation_model.translator.translate_batch([source], target_prefix=[target_prefix])
    # Drop the leading language token from the best hypothesis.
    target = results[0].hypotheses[0][1:]
    return translation_model.tokenizer.decode(translation_model.tokenizer.convert_tokens_to_ids(target))
class OnlineTranslation:
    """Incremental translation with a local-agreement policy: output is only
    committed once two consecutive updates agree on it."""

    def __init__(self, translation_model: "TranslationModel", input_languages: list, output_languages: list):
        # Output still subject to change on the next update.
        self.buffer = []
        # Output that has stabilised across two consecutive updates.
        self.commited = []
        self.translation_model = translation_model
        self.input_languages = input_languages
        self.output_languages = output_languages

    def compute_common_prefix(self, results):
        """Commit the prefix on which the previous buffer and *results* agree;
        keep the remainder of *results* as the new buffer.

        Bug fix: the original loop did not stop at the first mismatch, so it
        kept comparing against the reassigned (shorter) buffer — raising
        IndexError — and it never committed anything when the buffer and the
        new results agreed completely.
        """
        limit = min(len(self.buffer), len(results))
        i = 0
        while i < limit and self.buffer[i] == results[i]:
            i += 1
        self.commited.extend(self.buffer[:i])
        self.buffer = results[i:]

    def translate(self, input, input_lang=None, output_lang=None):
        """Translate *input* from *input_lang* to *output_lang*.

        Languages default to the first configured input/output language.
        Returns "" for empty/falsy input, otherwise the decoded translation.
        """
        if not input:
            return ""
        if input_lang is None:
            input_lang = self.input_languages[0]
        if output_lang is None:
            output_lang = self.output_languages[0]
        # NOTE(review): nllb_input_lang is currently unused — the tokenizer dict
        # is keyed on the raw code; confirm whether the NLLB code was intended.
        nllb_input_lang = get_nllb_code(input_lang)
        nllb_output_lang = get_nllb_code(output_lang)
        tokenizer = self.translation_model.tokenizer[input_lang]
        source = tokenizer.convert_ids_to_tokens(tokenizer.encode(input))
        # NLLB uses the target-language token as the decoder prefix.
        results = self.translation_model.translator.translate_batch(
            [source], target_prefix=[[nllb_output_lang]])
        # Drop the leading language token from the best hypothesis.
        target = results[0].hypotheses[0][1:]
        return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
if __name__ == '__main__':
    # Minimal smoke test: load the shared model for English and translate one
    # sentence to French. (Stale pre-OnlineTranslation demo lines removed —
    # they called the old single-tokenizer load_model/translate API.)
    output_lang = 'fr'
    input_lang = "en"
    shared_model = load_model([input_lang])
    online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang])
    result = online_translation.translate('Hello world')
    print(result)

View File

@@ -368,6 +368,27 @@ label {
color: var(--label-trans-text);
}
/* Chip-style container for an inline translation line. */
.label_translation {
    background-color: var(--chip-bg);
    border-radius: 10px;
    padding: 4px 8px;
    margin-top: 4px;
    font-size: 14px;
    color: var(--text);
    display: flex;
    align-items: flex-start;
    gap: 4px;
}
/* Translation icon: fixed 12x12, nudged down to align with the first text line.
   (Merged two duplicate `.label_translation img` rules into one.) */
.label_translation img {
    margin-top: 2px;
    width: 12px;
    height: 12px;
}
#timeInfo {
color: var(--muted);
margin-left: 10px;
@@ -417,6 +438,7 @@ label {
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
display: none;
}
.loading {

View File

@@ -332,6 +332,13 @@ function renderLinesWithBuffer(
}
let currentLineText = item.text || "";
if (item.translation) {
currentLineText += `<div class="label_translation">
<img src="/web/src/translate.svg" alt="Translation" width="12" height="12" />
<span>${item.translation}</span>
</div>`;
}
if (idx === lines.length - 1) {
if (!isFinalizing && item.speaker !== -2) {