mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
translator takes all the tokens from the queue
This commit is contained in:
@@ -16,6 +16,17 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
|
||||
|
||||
async def get_all_from_queue(queue):
    """Drain every item that is already buffered in ``queue`` without waiting.

    Items are pulled with ``get_nowait()`` until the queue reports
    ``asyncio.QueueEmpty``; the function never suspends.

    Args:
        queue: Object exposing ``get_nowait()`` that raises
            ``asyncio.QueueEmpty`` once it has nothing left.

    Returns:
        list: The drained items in retrieval order (empty if nothing was
        buffered). ``task_done()`` is NOT called here; accounting for the
        returned items is left to the caller.
    """
    drained = []
    while True:
        try:
            drained.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            # Nothing left to take right now: hand back what was collected.
            return drained
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
Processes audio streams for transcription and diarization.
|
||||
@@ -265,6 +276,8 @@ class AudioProcessor:
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
logger.debug("Sentinel put into diarization_queue.")
|
||||
if self.args.target_language and self.translation_queue:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
|
||||
async def transcription_processor(self):
|
||||
@@ -308,9 +321,6 @@ class AudioProcessor:
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
|
||||
|
||||
|
||||
|
||||
self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = self.online.process_iter()
|
||||
|
||||
@@ -338,6 +348,11 @@ class AudioProcessor:
|
||||
await self.update_transcription(
|
||||
new_tokens, buffer_text, new_end_buffer, self.sep
|
||||
)
|
||||
|
||||
if new_tokens and self.args.target_language and self.translation_queue:
|
||||
for token in new_tokens:
|
||||
await self.translation_queue.put(token)
|
||||
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
@@ -398,9 +413,44 @@ class AudioProcessor:
|
||||
# In the future we want to support a different language per speaker, so this will become more complex.
|
||||
while True:
|
||||
try:
|
||||
item = await self.translation_queue.get()
|
||||
token = await self.translation_queue.get() #block until at least 1 token
|
||||
if token is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
|
||||
# Grab every token already available in the queue: the more words per batch, the more precise the translation.
|
||||
tokens_to_process = [token]
|
||||
additional_tokens = await get_all_from_queue(self.translation_queue)
|
||||
|
||||
sentinel_found = False
|
||||
for additional_token in additional_tokens:
|
||||
if additional_token is SENTINEL:
|
||||
sentinel_found = True
|
||||
break
|
||||
tokens_to_process.append(additional_token)
|
||||
if tokens_to_process:
|
||||
online_translation.insert_tokens(tokens_to_process)
|
||||
translations = online_translation.process()
|
||||
print(translations)
|
||||
|
||||
self.translation_queue.task_done()
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
|
||||
if sentinel_found:
|
||||
logger.debug("Translation processor received sentinel in batch. Finishing.")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'token' in locals() and token is not SENTINEL:
|
||||
self.translation_queue.task_done()
|
||||
if 'additional_tokens' in locals():
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
logger.info("Translation processor task finished.")
|
||||
|
||||
async def results_formatter(self):
|
||||
"""Format processing results for output."""
|
||||
@@ -546,8 +596,10 @@ class AudioProcessor:
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
if self.args.target_language and self.args.language != 'auto':
|
||||
if self.args.target_language and self.args.lan != 'auto':
|
||||
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
|
||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||
processing_tasks_for_watchdog.append(self.translation_task)
|
||||
|
||||
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||
|
||||
@@ -136,11 +136,11 @@ class TranscriptionEngine:
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.language == 'auto':
|
||||
if self.args.lan == 'auto':
|
||||
raise Exception('Translation cannot be set with language auto')
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
self.translation_model = load_model([self.args.language]) #in the future we want to handle different languages for different speakers
|
||||
self.translation_model = load_model([self.args.lan]) #in the future we want to handle different languages for different speakers
|
||||
|
||||
TranscriptionEngine._initialized = True
|
||||
|
||||
@@ -181,4 +181,4 @@ def online_translation_factory(args, translation_model):
|
||||
# one shared NLLB model for all speakers
|
||||
#one tokenizer per speaker/language
|
||||
from whisperlivekit.translation.translation import OnlineTranslation
|
||||
online = OnlineTranslation(translation_model, [args.language], [args.target_language])
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
@@ -6,7 +6,7 @@ from whisperlivekit.remove_silences import handle_silences
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?'}
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
CHECK_AROUND = 4
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
@@ -59,6 +59,7 @@ def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
|
||||
|
||||
def format_output(state, silence, current_time, diarization, debug):
|
||||
tokens = state["tokens"]
|
||||
translated_tokens = state["translated_tokens"] # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
buffer_transcription = state["buffer_transcription"]
|
||||
buffer_diarization = state["buffer_diarization"]
|
||||
end_attributed_speaker = state["end_attributed_speaker"]
|
||||
|
||||
@@ -174,7 +174,6 @@ class PaddedAlignAttWhisper:
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
|
||||
def remove_hooks(self):
|
||||
print('remove hook')
|
||||
for hook in self.l_hooks:
|
||||
hook.remove()
|
||||
|
||||
|
||||
@@ -31,6 +31,10 @@ class SpeakerSegment(TimedText):
|
||||
"""
|
||||
pass
|
||||
|
||||
@dataclass
class Translation(TimedText):
    """``TimedText`` subclass used to tag translated text.

    Declares no fields or behaviour of its own; everything is inherited
    from ``TimedText``.
    """
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: float
|
||||
@@ -4,15 +4,18 @@ import transformers
|
||||
from dataclasses import dataclass
|
||||
import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from timed_objects import Translation
|
||||
|
||||
|
||||
# In the diarization case, we may want to translate just one speaker, or at least start the sentences there.
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
tokenizer: dict()
|
||||
tokenizer: dict
|
||||
|
||||
def load_model(src_langs):
|
||||
MODEL = 'nllb-200-distilled-600M-ctranslate2'
|
||||
@@ -38,7 +41,8 @@ def translate(input, translation_model, tgt_lang):
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
self.buffer = []
|
||||
self.commited = []
|
||||
self.validated = []
|
||||
self.translation_pending_validation = ''
|
||||
self.translation_model = translation_model
|
||||
self.input_languages = input_languages
|
||||
self.output_languages = output_languages
|
||||
@@ -68,6 +72,39 @@ class OnlineTranslation:
|
||||
results = self.translation_model.tokenizer[input_lang].decode(self.translation_model.tokenizer[input_lang].convert_tokens_to_ids(target))
|
||||
return results
|
||||
|
||||
def translate_tokens(self, tokens):
    """Translate a run of tokens into a single ``Translation``.

    The token texts are joined with single spaces and passed to
    ``self.translate``. The result spans from the first token's ``start``
    to the last token's ``end``.

    Args:
        tokens: Sequence of objects exposing ``text``, ``start`` and ``end``.

    Returns:
        A ``Translation`` for the joined text, or ``None`` when ``tokens``
        is empty (or otherwise falsy).
    """
    if not tokens:
        return None

    source_text = ' '.join(tok.text for tok in tokens)
    # Read the time span before translating, as the original did.
    span_start, span_end = tokens[0].start, tokens[-1].end
    return Translation(
        text=self.translate(source_text),
        start=span_start,
        end=span_end,
    )
|
||||
|
||||
|
||||
|
||||
def insert_tokens(self, tokens):
    """Append ``tokens`` to the pending buffer (``self.buffer``), in order.

    Args:
        tokens: Iterable of tokens to queue for translation.

    Returns:
        None. The buffer is extended in place.
    """
    self.buffer.extend(tokens)
|
||||
|
||||
def process(self):
    """Translate buffered tokens, closing a sentence at each punctuation token.

    The buffer is scanned from the front. Whenever a token whose ``text``
    is in ``PUNCTUATION_MARKS`` is found, everything up to and including
    it is translated via ``self.translate_tokens``, stored in
    ``self.validated`` and removed from ``self.buffer``; scanning then
    restarts on what is left.

    Returns:
        list: A new list made of ``self.validated`` followed by one extra
        entry, the translation of the tokens still in the buffer. That
        last entry is whatever ``self.translate_tokens`` returns for the
        leftover buffer, so it can be ``None`` when nothing is left.
    """
    cursor = 0
    while cursor < len(self.buffer):
        if self.buffer[cursor].text not in PUNCTUATION_MARKS:
            cursor += 1
            continue

        # Sentence boundary: translate and store the finished sentence.
        sentence_end = cursor + 1
        finished = self.translate_tokens(self.buffer[:sentence_end])
        self.validated.append(finished)
        self.buffer = self.buffer[sentence_end:]
        cursor = 0  # restart on the shortened buffer

    pending = self.translate_tokens(self.buffer)
    return self.validated + [pending]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
output_lang = 'fr'
|
||||
|
||||
Reference in New Issue
Block a user