mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
translation compatible with auto and detected language
This commit is contained in:
BIN
architecture.png
BIN
architecture.png
Binary file not shown.
|
Before Width: | Height: | Size: 368 KiB After Width: | Height: | Size: 390 KiB |
@@ -4,7 +4,7 @@ from time import time, sleep
|
||||
import math
|
||||
import logging
|
||||
import traceback
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output
|
||||
@@ -67,7 +67,9 @@ class AudioProcessor:
|
||||
self.last_response_content = FrontData()
|
||||
self.last_detected_speaker = None
|
||||
self.speaker_languages = {}
|
||||
|
||||
self.cumulative_pcm_len = 0
|
||||
self.diarization_before_transcription = False
|
||||
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.tokenizer = models.tokenizer
|
||||
@@ -100,13 +102,14 @@ class AudioProcessor:
|
||||
self.diarization_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
self.online_translation = None
|
||||
|
||||
if self.args.transcription:
|
||||
self.online = online_factory(self.args, models.asr, models.tokenizer)
|
||||
self.sep = self.online.asr.sep
|
||||
if self.args.diarization:
|
||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||
if self.args.target_language:
|
||||
if models.translation_model:
|
||||
self.online_translation = online_translation_factory(self.args, models.translation_model)
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
@@ -199,11 +202,11 @@ class AudioProcessor:
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
if self.args.target_language and self.translation_queue:
|
||||
if self.online_translation:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
async def transcription_processor(self):
|
||||
@@ -217,11 +220,6 @@ class AudioProcessor:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
self.transcription_queue.task_done()
|
||||
break
|
||||
|
||||
if not self.online:
|
||||
logger.warning("Transcription processor: self.online not initialized.")
|
||||
self.transcription_queue.task_done()
|
||||
continue
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
@@ -234,12 +232,12 @@ class AudioProcessor:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
|
||||
continue
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
if isinstance(item, np.ndarray):
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
self.online.new_speaker(item)
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
else:
|
||||
raise Exception('item should be pcm_array')
|
||||
|
||||
logger.info(asr_processing_logs)
|
||||
|
||||
duration_this_chunk = len(pcm_array) / self.sample_rate
|
||||
cumulative_pcm_duration_stream_time += duration_this_chunk
|
||||
@@ -295,8 +293,7 @@ class AudioProcessor:
|
||||
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
"""Process audio chunks for speaker diarization."""
|
||||
buffer_diarization = Transcript()
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
self.current_speaker = 0
|
||||
while True:
|
||||
try:
|
||||
item = await self.diarization_queue.get()
|
||||
@@ -305,7 +302,6 @@ class AudioProcessor:
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
elif type(item) is Silence:
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
continue
|
||||
elif isinstance(item, np.ndarray):
|
||||
@@ -315,22 +311,26 @@ class AudioProcessor:
|
||||
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
segments = diarization_obj.get_segments()
|
||||
|
||||
async with self.lock:
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
|
||||
# if last_segment is not None and last_segment.speaker != self.last_detected_speaker:
|
||||
# if not self.speaker_languages.get(last_segment.speaker, None):
|
||||
# self.last_detected_speaker = last_segment.speaker
|
||||
# self.online.on_new_speaker(last_segment)
|
||||
|
||||
if self.diarization_before_transcription:
|
||||
if segments and segments[-1].speaker != self.current_speaker:
|
||||
self.current_speaker = segments[-1].speaker
|
||||
cut_at = int(segments[-1].start*16000 - (self.cumulative_pcm_len))
|
||||
await self.transcription_queue.put(pcm_array[cut_at:])
|
||||
await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=cut_at))
|
||||
await self.transcription_queue.put(pcm_array[:cut_at])
|
||||
else:
|
||||
await self.transcription_queue.put(pcm_array)
|
||||
else:
|
||||
async with self.lock:
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
self.cumulative_pcm_len += len(pcm_array)
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
@@ -340,7 +340,7 @@ class AudioProcessor:
|
||||
self.diarization_queue.task_done()
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self, online_translation):
|
||||
async def translation_processor(self):
|
||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||
# And the speaker is attributed given the segments used for the translation
|
||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||
@@ -352,7 +352,7 @@ class AudioProcessor:
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
elif type(item) is Silence:
|
||||
online_translation.insert_silence(item.duration)
|
||||
self.online_translation.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
# get all the available tokens for translation. The more words, the more precise
|
||||
@@ -366,9 +366,8 @@ class AudioProcessor:
|
||||
break
|
||||
tokens_to_process.append(additional_token)
|
||||
if tokens_to_process:
|
||||
online_translation.insert_tokens(tokens_to_process)
|
||||
self.translated_segments = await asyncio.to_thread(online_translation.process)
|
||||
|
||||
self.online_translation.insert_tokens(tokens_to_process)
|
||||
self.translated_segments = await asyncio.to_thread(self.online_translation.process)
|
||||
self.translation_queue.task_done()
|
||||
for _ in additional_tokens:
|
||||
self.translation_queue.task_done()
|
||||
@@ -445,8 +444,8 @@ class AudioProcessor:
|
||||
response = FrontData(
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text,
|
||||
buffer_diarization=buffer_diarization,
|
||||
buffer_transcription=buffer_transcription.text.strip(),
|
||||
buffer_diarization=buffer_diarization.strip(),
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
@@ -505,8 +504,8 @@ class AudioProcessor:
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
if self.args.target_language and self.args.lan != 'auto':
|
||||
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation))
|
||||
if self.online_translation:
|
||||
self.translation_task = asyncio.create_task(self.translation_processor())
|
||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||
processing_tasks_for_watchdog.append(self.translation_task)
|
||||
|
||||
@@ -628,7 +627,7 @@ class AudioProcessor:
|
||||
silence_buffer = Silence(duration=time() - self.start_silence)
|
||||
|
||||
if silence_buffer:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
@@ -636,7 +635,7 @@ class AudioProcessor:
|
||||
await self.translation_queue.put(silence_buffer)
|
||||
|
||||
if not self.silence:
|
||||
if self.args.transcription and self.transcription_queue:
|
||||
if not self.diarization_before_transcription and self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_array.copy())
|
||||
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
|
||||
@@ -145,8 +145,8 @@ class TranscriptionEngine:
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.lan == 'auto':
|
||||
raise Exception('Translation cannot be set with language auto')
|
||||
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
||||
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||
else:
|
||||
from whisperlivekit.translation.translation import load_model
|
||||
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import logging
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.warmup import load_file
|
||||
from .whisper import load_model, tokenizer
|
||||
from .whisper.audio import TOKENS_PER_SECOND
|
||||
@@ -93,14 +93,16 @@ class SimulStreamingOnlineProcessor:
|
||||
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def on_new_speaker(self, last_segment: SpeakerSegment):
|
||||
self.model.on_new_speaker(last_segment)
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
self.process_iter(is_last=True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
|
||||
self.model.speaker = change_speaker.speaker
|
||||
self.global_time_offset = change_speaker.start
|
||||
|
||||
def get_buffer(self):
|
||||
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||
return concat_buffer
|
||||
|
||||
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Process accumulated audio chunks using SimulStreaming.
|
||||
|
||||
@@ -66,7 +66,7 @@ class PaddedAlignAttWhisper:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
|
||||
self.speaker = -1
|
||||
self.decode_options = DecodingOptions(
|
||||
language = cfg.language,
|
||||
without_timestamps = True,
|
||||
@@ -152,7 +152,7 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.second_word_timestamp = None
|
||||
self.first_timestamp = None
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
@@ -432,9 +432,9 @@ class PaddedAlignAttWhisper:
|
||||
end_encode = time()
|
||||
# print('Encoder duration:', end_encode-beg_encode)
|
||||
|
||||
if self.cfg.language == "auto" and self.detected_language is None and self.second_word_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.second_word_timestamp
|
||||
if seconds_since_start >= 5.0:
|
||||
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.first_timestamp
|
||||
if seconds_since_start >= 2.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
@@ -445,8 +445,6 @@ class PaddedAlignAttWhisper:
|
||||
self.init_context()
|
||||
self.detected_language = top_lan
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
else:
|
||||
logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
|
||||
|
||||
self.trim_context()
|
||||
current_tokens = self._current_tokens()
|
||||
@@ -591,8 +589,8 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
self._clean_cache()
|
||||
|
||||
if len(l_absolute_timestamps) >=2 and self.second_word_timestamp is None:
|
||||
self.second_word_timestamp = l_absolute_timestamps[1]
|
||||
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
|
||||
self.first_timestamp = l_absolute_timestamps[0]
|
||||
|
||||
|
||||
timestamped_words = []
|
||||
@@ -609,10 +607,11 @@ class PaddedAlignAttWhisper:
|
||||
end=current_timestamp + 0.1,
|
||||
text= word,
|
||||
probability=0.95,
|
||||
speaker=self.speaker,
|
||||
detected_language=self.detected_language
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
return timestamped_words
|
||||
return timestamped_words
|
||||
@@ -160,13 +160,17 @@ class FrontData():
|
||||
if self.error:
|
||||
_dict['error'] = self.error
|
||||
return _dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChangeSpeaker:
|
||||
speaker: int
|
||||
start: int
|
||||
|
||||
@dataclass
|
||||
class State():
|
||||
tokens: list
|
||||
translated_segments: list
|
||||
buffer_transcription: str
|
||||
buffer_diarization: str
|
||||
end_buffer: float
|
||||
end_attributed_speaker: float
|
||||
remaining_time_transcription: float
|
||||
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
import ctranslate2
|
||||
import torch
|
||||
import transformers
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import huggingface_hub
|
||||
from whisperlivekit.translation.mapping_languages import get_nllb_code
|
||||
from whisperlivekit.timed_objects import Translation
|
||||
@@ -18,9 +18,20 @@ MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider t
|
||||
@dataclass
|
||||
class TranslationModel():
|
||||
translator: ctranslate2.Translator
|
||||
tokenizer: dict
|
||||
device: str
|
||||
tokenizer: dict = field(default_factory=dict)
|
||||
backend_type: str = 'ctranslate2'
|
||||
model_size: str = '600M'
|
||||
|
||||
def get_tokenizer(self, input_lang):
|
||||
if not self.tokenizer.get(input_lang, False):
|
||||
self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained(
|
||||
f"facebook/nllb-200-distilled-{self.model_size}",
|
||||
src_lang=input_lang,
|
||||
clean_up_tokenization_spaces=True
|
||||
)
|
||||
return self.tokenizer[input_lang]
|
||||
|
||||
|
||||
def load_model(src_langs, backend='ctranslate2', model_size='600M'):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -33,14 +44,20 @@ def load_model(src_langs, backend='ctranslate2', model_size='600M'):
|
||||
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}")
|
||||
tokenizer = dict()
|
||||
for src_lang in src_langs:
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
if src_lang != 'auto':
|
||||
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
|
||||
|
||||
return TranslationModel(
|
||||
translation_model = TranslationModel(
|
||||
translator=translator,
|
||||
tokenizer=tokenizer,
|
||||
backend_type=backend,
|
||||
device = device
|
||||
device = device,
|
||||
model_size = model_size
|
||||
)
|
||||
for src_lang in src_langs:
|
||||
if src_lang != 'auto':
|
||||
translation_model.get_tokenizer(src_lang)
|
||||
return translation_model
|
||||
|
||||
class OnlineTranslation:
|
||||
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
|
||||
@@ -63,16 +80,12 @@ class OnlineTranslation:
|
||||
self.commited.extend(self.buffer[:i])
|
||||
self.buffer = results[i:]
|
||||
|
||||
def translate(self, input, input_lang=None, output_lang=None):
|
||||
def translate(self, input, input_lang, output_lang):
|
||||
if not input:
|
||||
return ""
|
||||
if input_lang is None:
|
||||
input_lang = self.input_languages[0]
|
||||
if output_lang is None:
|
||||
output_lang = self.output_languages[0]
|
||||
nllb_output_lang = get_nllb_code(output_lang)
|
||||
|
||||
tokenizer = self.translation_model.tokenizer[input_lang]
|
||||
tokenizer = self.translation_model.get_tokenizer(input_lang)
|
||||
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
|
||||
|
||||
if self.translation_model.backend_type == 'ctranslate2':
|
||||
@@ -90,7 +103,15 @@ class OnlineTranslation:
|
||||
text = ' '.join([token.text for token in tokens])
|
||||
start = tokens[0].start
|
||||
end = tokens[-1].end
|
||||
translated_text = self.translate(text)
|
||||
if self.input_languages[0] == 'auto':
|
||||
input_lang = tokens[0].detected_language
|
||||
else:
|
||||
input_lang = self.input_languages[0]
|
||||
|
||||
translated_text = self.translate(text,
|
||||
input_lang,
|
||||
self.output_languages[0]
|
||||
)
|
||||
translation = Translation(
|
||||
text=translated_text,
|
||||
start=start,
|
||||
|
||||
Reference in New Issue
Block a user