translator takes all the tokens from the queue

This commit is contained in:
Quentin Fuxa
2025-09-09 19:55:39 +02:00
parent da8726b2cb
commit add7ea07ee
6 changed files with 105 additions and 12 deletions

View File

@@ -16,6 +16,17 @@ logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
async def get_all_from_queue(queue):
    """Drain *queue* without blocking and return its items in FIFO order.

    Returns an empty list when the queue is already empty; never awaits,
    so no new items can arrive while draining.
    """
    drained = []
    while True:
        try:
            drained.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            return drained
class AudioProcessor:
"""
Processes audio streams for transcription and diarization.
@@ -265,6 +276,8 @@ class AudioProcessor:
if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(SENTINEL)
logger.debug("Sentinel put into diarization_queue.")
if self.args.target_language and self.translation_queue:
await self.translation_queue.put(SENTINEL)
async def transcription_processor(self):
@@ -308,9 +321,6 @@ class AudioProcessor:
cumulative_pcm_duration_stream_time += duration_this_chunk
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
new_tokens, current_audio_processed_upto = self.online.process_iter()
@@ -338,6 +348,11 @@ class AudioProcessor:
await self.update_transcription(
new_tokens, buffer_text, new_end_buffer, self.sep
)
if new_tokens and self.args.target_language and self.translation_queue:
for token in new_tokens:
await self.translation_queue.put(token)
self.transcription_queue.task_done()
except Exception as e:
@@ -398,9 +413,44 @@ class AudioProcessor:
# in the future we want to have different languages for each speaker etc, so it will be more complex.
while True:
try:
item = await self.translation_queue.get()
token = await self.translation_queue.get() #block until at least 1 token
if token is SENTINEL:
logger.debug("Translation processor received sentinel. Finishing.")
self.translation_queue.task_done()
break
# get all the available tokens for translation. The more words, the more precise
tokens_to_process = [token]
additional_tokens = await get_all_from_queue(self.translation_queue)
sentinel_found = False
for additional_token in additional_tokens:
if additional_token is SENTINEL:
sentinel_found = True
break
tokens_to_process.append(additional_token)
if tokens_to_process:
online_translation.insert_tokens(tokens_to_process)
translations = online_translation.process()
print(translations)
self.translation_queue.task_done()
for _ in additional_tokens:
self.translation_queue.task_done()
if sentinel_found:
logger.debug("Translation processor received sentinel in batch. Finishing.")
break
except Exception as e:
logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'token' in locals() and token is not SENTINEL:
self.translation_queue.task_done()
if 'additional_tokens' in locals():
for _ in additional_tokens:
self.translation_queue.task_done()
logger.info("Translation processor task finished.")
async def results_formatter(self):
"""Format processing results for output."""
@@ -546,8 +596,10 @@ class AudioProcessor:
self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task)
if self.args.target_language and self.args.language != 'auto':
if self.args.target_language and self.args.lan != 'auto':
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task)
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)

View File

@@ -136,11 +136,11 @@ class TranscriptionEngine:
self.translation_model = None
if self.args.target_language:
if self.args.language == 'auto':
if self.args.lan == 'auto':
raise Exception('Translation cannot be set with language auto')
else:
from whisperlivekit.translation.translation import load_model
self.translation_model = load_model([self.args.language]) #in the future we want to handle different languages for different speakers
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True
@@ -181,4 +181,4 @@ def online_translation_factory(args, translation_model):
#one shared nllb model for all speaker
#one tokenizer per speaker/language
from whisperlivekit.translation.translation import OnlineTranslation
online = OnlineTranslation(translation_model, [args.language], [args.target_language])
return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -6,7 +6,7 @@ from whisperlivekit.remove_silences import handle_silences
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?'}
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
CHECK_AROUND = 4
def format_time(seconds: float) -> str:
@@ -59,6 +59,7 @@ def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
def format_output(state, silence, current_time, diarization, debug):
tokens = state["tokens"]
translated_tokens = state["translated_tokens"] # Here we will attribute the speakers only based on the timestamps of the segments
buffer_transcription = state["buffer_transcription"]
buffer_diarization = state["buffer_diarization"]
end_attributed_speaker = state["end_attributed_speaker"]

View File

@@ -174,7 +174,6 @@ class PaddedAlignAttWhisper:
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
def remove_hooks(self):
    """Detach every forward hook previously registered in ``self.l_hooks``.

    The hook handles themselves know how to unregister (``hook.remove()``);
    the list is left in place so callers can re-register later if needed.
    """
    # Dropped the leftover `print('remove hook')` debug statement.
    for hook in self.l_hooks:
        hook.remove()

View File

@@ -31,6 +31,10 @@ class SpeakerSegment(TimedText):
"""
pass
@dataclass
class Translation(TimedText):
    """A translated text segment; fields are inherited from TimedText."""
    pass
@dataclass
class Silence:
    """A stretch of silence in the audio stream, measured in seconds."""
    duration: float

View File

@@ -4,15 +4,18 @@ import transformers
from dataclasses import dataclass
import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code
from timed_objects import Translation
#In diarization case, we may want to translate just one speaker, or at least start the sentences there
# Sentence-final punctuation used to split the token buffer into sentences.
# NOTE(review): the last three entries were garbled to empty strings in this
# rendering; restored as the CJK full-width marks — an empty string here would
# make blank tokens terminate a sentence (confirm against upstream).
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
@dataclass
class TranslationModel():
    """Bundle of a shared ctranslate2 translator and per-language tokenizers."""
    translator: ctranslate2.Translator
    # Maps input-language code -> tokenizer (indexed as tokenizer[input_lang]).
    # Annotation fixed: `dict()` was a call expression, not a type.
    tokenizer: dict
def load_model(src_langs):
MODEL = 'nllb-200-distilled-600M-ctranslate2'
@@ -38,7 +41,8 @@ def translate(input, translation_model, tgt_lang):
class OnlineTranslation:
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
self.buffer = []
self.commited = []
self.validated = []
self.translation_pending_validation = ''
self.translation_model = translation_model
self.input_languages = input_languages
self.output_languages = output_languages
@@ -68,6 +72,39 @@ class OnlineTranslation:
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
return results
def translate_tokens(self, tokens):
    """Translate *tokens* as one unit.

    Returns a Translation spanning from the first token's start to the last
    token's end, or None when *tokens* is empty.
    """
    if not tokens:
        return None
    source_text = ' '.join(tok.text for tok in tokens)
    return Translation(
        text=self.translate(source_text),
        start=tokens[0].start,
        end=tokens[-1].end,
    )
def insert_tokens(self, tokens):
    """Append newly transcribed tokens to the pending translation buffer."""
    # Removed the dead trailing `pass` statement.
    self.buffer.extend(tokens)
def process(self):
    """Translate the buffered tokens.

    Every run of tokens ending in a sentence-final punctuation mark is
    translated once, committed to ``self.validated`` and removed from the
    buffer; the incomplete tail is translated provisionally on each call.

    Returns the list of validated translations plus, when the buffer is
    non-empty, the provisional translation of the tail.
    """
    idx = 0
    while idx < len(self.buffer):
        if self.buffer[idx].text in PUNCTUATION_MARKS:
            # Complete sentence: commit it and restart the scan on the rest.
            self.validated.append(self.translate_tokens(self.buffer[:idx + 1]))
            self.buffer = self.buffer[idx + 1:]
            idx = 0
        else:
            idx += 1
    remaining = self.translate_tokens(self.buffer)
    if remaining is None:
        # Bug fix: an empty buffer previously produced a trailing None entry
        # in the returned list.
        return list(self.validated)
    return self.validated + [remaining]
if __name__ == '__main__':
output_lang = 'fr'