diff --git a/architecture.png b/architecture.png index b9aa73f..64d5f9c 100644 Binary files a/architecture.png and b/architecture.png differ diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index ea75b47..524fac5 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -4,7 +4,7 @@ from time import time, sleep import math import logging import traceback -from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript +from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.results_formater import format_output @@ -67,7 +67,9 @@ class AudioProcessor: self.last_response_content = FrontData() self.last_detected_speaker = None self.speaker_languages = {} - + self.cumulative_pcm_len = 0 + self.diarization_before_transcription = False + # Models and processing self.asr = models.asr self.tokenizer = models.tokenizer @@ -100,13 +102,14 @@ class AudioProcessor: self.diarization_task = None self.watchdog_task = None self.all_tasks_for_cleanup = [] + self.online_translation = None if self.args.transcription: self.online = online_factory(self.args, models.asr, models.tokenizer) self.sep = self.online.asr.sep if self.args.diarization: self.diarization = online_diarization_factory(self.args, models.diarization_model) - if self.args.target_language: + if models.translation_model: self.online_translation = online_translation_factory(self.args, models.translation_model) def convert_pcm_to_float(self, pcm_buffer): @@ -199,11 +202,11 @@ class AudioProcessor: await asyncio.sleep(0.2) logger.info("FFmpeg stdout processing finished. 
Signaling downstream processors if needed.") - if self.args.transcription and self.transcription_queue: + if not self.diarization_before_transcription and self.transcription_queue: await self.transcription_queue.put(SENTINEL) if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(SENTINEL) - if self.args.target_language and self.translation_queue: + if self.online_translation: await self.translation_queue.put(SENTINEL) async def transcription_processor(self): @@ -217,11 +220,6 @@ class AudioProcessor: logger.debug("Transcription processor received sentinel. Finishing.") self.transcription_queue.task_done() break - - if not self.online: - logger.warning("Transcription processor: self.online not initialized.") - self.transcription_queue.task_done() - continue asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer) @@ -234,12 +232,13 @@ class AudioProcessor: cumulative_pcm_duration_stream_time += item.duration self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0) continue - logger.info(asr_processing_logs) - - if isinstance(item, np.ndarray): + elif isinstance(item, ChangeSpeaker): + self.online.new_speaker(item) + continue + elif isinstance(item, np.ndarray): pcm_array = item - else: - raise Exception('item should be pcm_array') + + logger.info(asr_processing_logs) duration_this_chunk = len(pcm_array) / self.sample_rate cumulative_pcm_duration_stream_time += duration_this_chunk @@ -295,8 +293,7 @@ class AudioProcessor: async def diarization_processor(self, diarization_obj): """Process audio chunks for speaker diarization.""" - buffer_diarization = Transcript() - cumulative_pcm_duration_stream_time = 0.0 + self.current_speaker = 0 while True: try: item = await self.diarization_queue.get() @@ -305,7 +302,6 @@ class AudioProcessor: self.diarization_queue.task_done() break elif type(item) is
Silence: - cumulative_pcm_duration_stream_time += item.duration diarization_obj.insert_silence(item.duration) continue elif isinstance(item, np.ndarray): @@ -315,22 +311,27 @@ class AudioProcessor: # Process diarization await diarization_obj.diarize(pcm_array) segments = diarization_obj.get_segments() - async with self.lock: - self.tokens = diarization_obj.assign_speakers_to_tokens( - self.tokens, - use_punctuation_split=self.args.punctuation_split - ) - if len(self.tokens) > 0: - self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker) - - # if last_segment is not None and last_segment.speaker != self.last_detected_speaker: - # if not self.speaker_languages.get(last_segment.speaker, None): - # self.last_detected_speaker = last_segment.speaker - # self.online.on_new_speaker(last_segment) - + if self.diarization_before_transcription: + if segments and segments[-1].speaker != self.current_speaker: + self.current_speaker = segments[-1].speaker + cut_at = int(segments[-1].start*16000 - (self.cumulative_pcm_len)) + await self.transcription_queue.put(pcm_array[cut_at:]) + await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=cut_at)) + await self.transcription_queue.put(pcm_array[:cut_at]) + else: + await self.transcription_queue.put(pcm_array) + else: + async with self.lock: + self.tokens = diarization_obj.assign_speakers_to_tokens( + self.tokens, + use_punctuation_split=self.args.punctuation_split + ) + self.cumulative_pcm_len += len(pcm_array) + if len(self.tokens) > 0: + self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker) self.diarization_queue.task_done() except Exception as e: @@ -340,7 +340,7 @@ class AudioProcessor: self.diarization_queue.task_done() logger.info("Diarization processor task finished.") - async def translation_processor(self, online_translation): + async def translation_processor(self): # the idea is to ignore diarization for the moment. 
We use only transcription tokens. # And the speaker is attributed given the segments used for the translation # in the future we want to have different languages for each speaker etc, so it will be more complex. @@ -352,7 +352,7 @@ class AudioProcessor: self.translation_queue.task_done() break elif type(item) is Silence: - online_translation.insert_silence(item.duration) + self.online_translation.insert_silence(item.duration) continue # get all the available tokens for translation. The more words, the more precise @@ -366,9 +366,8 @@ class AudioProcessor: break tokens_to_process.append(additional_token) if tokens_to_process: - online_translation.insert_tokens(tokens_to_process) - self.translated_segments = await asyncio.to_thread(online_translation.process) - + self.online_translation.insert_tokens(tokens_to_process) + self.translated_segments = await asyncio.to_thread(self.online_translation.process) self.translation_queue.task_done() for _ in additional_tokens: self.translation_queue.task_done() @@ -445,8 +444,8 @@ class AudioProcessor: response = FrontData( status=response_status, lines=lines, - buffer_transcription=buffer_transcription.text, - buffer_diarization=buffer_diarization, + buffer_transcription=buffer_transcription.text.strip(), + buffer_diarization=buffer_diarization.strip(), remaining_time_transcription=state.remaining_time_transcription, remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 ) @@ -505,8 +504,8 @@ class AudioProcessor: self.all_tasks_for_cleanup.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task) - if self.args.target_language and self.args.lan != 'auto': - self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation)) + if self.online_translation: + self.translation_task = asyncio.create_task(self.translation_processor()) self.all_tasks_for_cleanup.append(self.translation_task) 
processing_tasks_for_watchdog.append(self.translation_task) @@ -628,7 +627,7 @@ class AudioProcessor: silence_buffer = Silence(duration=time() - self.start_silence) if silence_buffer: - if self.args.transcription and self.transcription_queue: + if not self.diarization_before_transcription and self.transcription_queue: await self.transcription_queue.put(silence_buffer) if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(silence_buffer) @@ -636,7 +635,7 @@ class AudioProcessor: await self.translation_queue.put(silence_buffer) if not self.silence: - if self.args.transcription and self.transcription_queue: + if not self.diarization_before_transcription and self.transcription_queue: await self.transcription_queue.put(pcm_array.copy()) if self.args.diarization and self.diarization_queue: diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 578e624..b4ef8d8 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -145,8 +145,8 @@ class TranscriptionEngine: self.translation_model = None if self.args.target_language: - if self.args.lan == 'auto': - raise Exception('Translation cannot be set with language auto') + if self.args.lan == 'auto' and self.args.backend != "simulstreaming": + raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming') else: from whisperlivekit.translation.translation import load_model self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index b9187ff..65ba194 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -4,7 +4,7 @@ import logging from typing import List, Tuple, Optional import logging import platform -from whisperlivekit.timed_objects import ASRToken, 
Transcript, SpeakerSegment +from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker from whisperlivekit.warmup import load_file from .whisper import load_model, tokenizer from .whisper.audio import TOKENS_PER_SECOND @@ -93,14 +93,16 @@ class SimulStreamingOnlineProcessor: self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. self.model.insert_audio(audio_tensor) - def on_new_speaker(self, last_segment: SpeakerSegment): - self.model.on_new_speaker(last_segment) + def new_speaker(self, change_speaker: ChangeSpeaker): + self.process_iter(is_last=True) self.model.refresh_segment(complete=True) - + self.model.speaker = change_speaker.speaker + self.global_time_offset = change_speaker.start + def get_buffer(self): concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') return concat_buffer - + def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: """ Process accumulated audio chunks using SimulStreaming. 
diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 3449768..250fa5b 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -66,7 +66,7 @@ class PaddedAlignAttWhisper: self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) logger.info(f"Model dimensions: {self.model.dims}") - + self.speaker = -1 self.decode_options = DecodingOptions( language = cfg.language, without_timestamps = True, @@ -152,7 +152,7 @@ class PaddedAlignAttWhisper: self.last_attend_frame = -self.cfg.rewind_threshold self.cumulative_time_offset = 0.0 - self.second_word_timestamp = None + self.first_timestamp = None if self.cfg.max_context_tokens is None: self.max_context_tokens = self.max_text_len @@ -432,9 +432,9 @@ class PaddedAlignAttWhisper: end_encode = time() # print('Encoder duration:', end_encode-beg_encode) - if self.cfg.language == "auto" and self.detected_language is None and self.second_word_timestamp: - seconds_since_start = self.segments_len() - self.second_word_timestamp - if seconds_since_start >= 5.0: + if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp: + seconds_since_start = self.segments_len() - self.first_timestamp + if seconds_since_start >= 2.0: language_tokens, language_probs = self.lang_id(encoder_feature) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) print(f"Detected language: {top_lan} with p={p:.4f}") @@ -445,8 +445,6 @@ class PaddedAlignAttWhisper: self.init_context() self.detected_language = top_lan logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") - else: - logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s") self.trim_context() current_tokens = self._current_tokens() @@ -591,8 +589,8 @@ class PaddedAlignAttWhisper: self._clean_cache() - if len(l_absolute_timestamps) >=2 and 
self.second_word_timestamp is None: - self.second_word_timestamp = l_absolute_timestamps[1] + if len(l_absolute_timestamps) >=2 and self.first_timestamp is None: + self.first_timestamp = l_absolute_timestamps[0] timestamped_words = [] @@ -609,10 +607,11 @@ class PaddedAlignAttWhisper: end=current_timestamp + 0.1, text= word, probability=0.95, + speaker=self.speaker, detected_language=self.detected_language ).with_offset( self.global_time_offset ) timestamped_words.append(timestamp_entry) - return timestamped_words + return timestamped_words \ No newline at end of file diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 5706343..76c5ef8 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -160,13 +160,17 @@ class FrontData(): if self.error: _dict['error'] = self.error return _dict - + +@dataclass +class ChangeSpeaker: + speaker: int + start: int + @dataclass class State(): tokens: list translated_segments: list buffer_transcription: str - buffer_diarization: str end_buffer: float end_attributed_speaker: float remaining_time_transcription: float diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index bb144c8..90ce47e 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -3,7 +3,7 @@ import time import ctranslate2 import torch import transformers -from dataclasses import dataclass +from dataclasses import dataclass, field import huggingface_hub from whisperlivekit.translation.mapping_languages import get_nllb_code from whisperlivekit.timed_objects import Translation @@ -18,9 +18,20 @@ MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider t @dataclass class TranslationModel(): translator: ctranslate2.Translator - tokenizer: dict device: str + tokenizer: dict = field(default_factory=dict) backend_type: str = 'ctranslate2' + model_size: str = '600M' + + def get_tokenizer(self, 
input_lang): + if not self.tokenizer.get(input_lang, False): + self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained( + f"facebook/nllb-200-distilled-{self.model_size}", + src_lang=input_lang, + clean_up_tokenization_spaces=True + ) + return self.tokenizer[input_lang] + def load_model(src_langs, backend='ctranslate2', model_size='600M'): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -33,14 +44,20 @@ def load_model(src_langs, backend='ctranslate2', model_size='600M'): translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}") tokenizer = dict() for src_lang in src_langs: - tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True) + if src_lang != 'auto': + tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True) - return TranslationModel( + translation_model = TranslationModel( translator=translator, tokenizer=tokenizer, backend_type=backend, - device = device + device = device, + model_size = model_size ) + for src_lang in src_langs: + if src_lang != 'auto': + translation_model.get_tokenizer(src_lang) + return translation_model class OnlineTranslation: def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): @@ -63,16 +80,12 @@ class OnlineTranslation: self.commited.extend(self.buffer[:i]) self.buffer = results[i:] - def translate(self, input, input_lang=None, output_lang=None): + def translate(self, input, input_lang, output_lang): if not input: return "" - if input_lang is None: - input_lang = self.input_languages[0] - if output_lang is None: - output_lang = self.output_languages[0] nllb_output_lang = get_nllb_code(output_lang) - tokenizer = self.translation_model.tokenizer[input_lang] + tokenizer = self.translation_model.get_tokenizer(input_lang) tokenizer_output = tokenizer(input, 
return_tensors="pt").to(self.translation_model.device) if self.translation_model.backend_type == 'ctranslate2': @@ -90,7 +103,15 @@ class OnlineTranslation: text = ' '.join([token.text for token in tokens]) start = tokens[0].start end = tokens[-1].end - translated_text = self.translate(text) + if self.input_languages[0] == 'auto': + input_lang = tokens[0].detected_language + else: + input_lang = self.input_languages[0] + + translated_text = self.translate(text, + input_lang, + self.output_languages[0] + ) translation = Translation( text=translated_text, start=start,