5 Commits

Author SHA1 Message Date
Quentin Fuxa
d24c110d55 to 0.2.11 2025-09-24 22:34:01 +02:00
Quentin Fuxa
4dd5d8bf8a translation compatible with auto and detected language 2025-09-22 11:20:00 +02:00
Quentin Fuxa
93f002cafb language detection after few seconds working 2025-09-20 11:08:00 +02:00
Quentin Fuxa
c5e30c2c07 svg loaded once in javascript, no more need for StaticFiles 2025-09-20 11:06:00 +02:00
Quentin Fuxa
1c2afb8bd2 svg loaded once in javascript, no more need for StaticFiles 2025-09-20 11:06:00 +02:00
14 changed files with 134 additions and 130 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 368 KiB

After

Width:  |  Height:  |  Size: 390 KiB

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "whisperlivekit" name = "whisperlivekit"
version = "0.2.10" version = "0.2.11"
description = "Real-time speech-to-text with speaker diarization using Whisper" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [

View File

@@ -4,7 +4,7 @@ from time import time, sleep
import math import math
import logging import logging
import traceback import traceback
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output from whisperlivekit.results_formater import format_output
@@ -59,7 +59,6 @@ class AudioProcessor:
self.tokens = [] self.tokens = []
self.translated_segments = [] self.translated_segments = []
self.buffer_transcription = Transcript() self.buffer_transcription = Transcript()
self.buffer_diarization = ""
self.end_buffer = 0 self.end_buffer = 0
self.end_attributed_speaker = 0 self.end_attributed_speaker = 0
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
@@ -68,7 +67,9 @@ class AudioProcessor:
self.last_response_content = FrontData() self.last_response_content = FrontData()
self.last_detected_speaker = None self.last_detected_speaker = None
self.speaker_languages = {} self.speaker_languages = {}
self.cumulative_pcm_len = 0
self.diarization_before_transcription = False
# Models and processing # Models and processing
self.asr = models.asr self.asr = models.asr
self.tokenizer = models.tokenizer self.tokenizer = models.tokenizer
@@ -101,13 +102,14 @@ class AudioProcessor:
self.diarization_task = None self.diarization_task = None
self.watchdog_task = None self.watchdog_task = None
self.all_tasks_for_cleanup = [] self.all_tasks_for_cleanup = []
self.online_translation = None
if self.args.transcription: if self.args.transcription:
self.online = online_factory(self.args, models.asr, models.tokenizer) self.online = online_factory(self.args, models.asr, models.tokenizer)
self.sep = self.online.asr.sep self.sep = self.online.asr.sep
if self.args.diarization: if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model) self.diarization = online_diarization_factory(self.args, models.diarization_model)
if self.args.target_language: if models.translation_model:
self.online_translation = online_translation_factory(self.args, models.translation_model) self.online_translation = online_translation_factory(self.args, models.translation_model)
def convert_pcm_to_float(self, pcm_buffer): def convert_pcm_to_float(self, pcm_buffer):
@@ -142,7 +144,6 @@ class AudioProcessor:
tokens=self.tokens.copy(), tokens=self.tokens.copy(),
translated_segments=self.translated_segments.copy(), translated_segments=self.translated_segments.copy(),
buffer_transcription=self.buffer_transcription, buffer_transcription=self.buffer_transcription,
buffer_diarization=self.buffer_diarization,
end_buffer=self.end_buffer, end_buffer=self.end_buffer,
end_attributed_speaker=self.end_attributed_speaker, end_attributed_speaker=self.end_attributed_speaker,
remaining_time_transcription=remaining_transcription, remaining_time_transcription=remaining_transcription,
@@ -154,7 +155,7 @@ class AudioProcessor:
async with self.lock: async with self.lock:
self.tokens = [] self.tokens = []
self.translated_segments = [] self.translated_segments = []
self.buffer_transcription = self.buffer_diarization = Transcript() self.buffer_transcription = Transcript()
self.end_buffer = self.end_attributed_speaker = 0 self.end_buffer = self.end_attributed_speaker = 0
self.beg_loop = time() self.beg_loop = time()
@@ -201,11 +202,11 @@ class AudioProcessor:
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.") logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
if self.args.transcription and self.transcription_queue: if not self.diarization_before_transcription and self.transcription_queue:
await self.transcription_queue.put(SENTINEL) await self.transcription_queue.put(SENTINEL)
if self.args.diarization and self.diarization_queue: if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(SENTINEL) await self.diarization_queue.put(SENTINEL)
if self.args.target_language and self.translation_queue: if self.online_translation:
await self.translation_queue.put(SENTINEL) await self.translation_queue.put(SENTINEL)
async def transcription_processor(self): async def transcription_processor(self):
@@ -219,11 +220,6 @@ class AudioProcessor:
logger.debug("Transcription processor received sentinel. Finishing.") logger.debug("Transcription processor received sentinel. Finishing.")
self.transcription_queue.task_done() self.transcription_queue.task_done()
break break
if not self.online:
logger.warning("Transcription processor: self.online not initialized.")
self.transcription_queue.task_done()
continue
asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE asr_internal_buffer_duration_s = len(getattr(self.online, 'audio_buffer', [])) / self.online.SAMPLING_RATE
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer) transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
@@ -236,12 +232,12 @@ class AudioProcessor:
cumulative_pcm_duration_stream_time += item.duration cumulative_pcm_duration_stream_time += item.duration
self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0) self.online.insert_silence(item.duration, self.tokens[-1].end if self.tokens else 0)
continue continue
logger.info(asr_processing_logs) elif isinstance(item, ChangeSpeaker):
self.online.new_speaker(item)
if isinstance(item, np.ndarray): elif isinstance(item, np.ndarray):
pcm_array = item pcm_array = item
else:
raise Exception('item should be pcm_array') logger.info(asr_processing_logs)
duration_this_chunk = len(pcm_array) / self.sample_rate duration_this_chunk = len(pcm_array) / self.sample_rate
cumulative_pcm_duration_stream_time += duration_this_chunk cumulative_pcm_duration_stream_time += duration_this_chunk
@@ -297,8 +293,7 @@ class AudioProcessor:
async def diarization_processor(self, diarization_obj): async def diarization_processor(self, diarization_obj):
"""Process audio chunks for speaker diarization.""" """Process audio chunks for speaker diarization."""
buffer_diarization = "" self.current_speaker = 0
cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
item = await self.diarization_queue.get() item = await self.diarization_queue.get()
@@ -307,7 +302,6 @@ class AudioProcessor:
self.diarization_queue.task_done() self.diarization_queue.task_done()
break break
elif type(item) is Silence: elif type(item) is Silence:
cumulative_pcm_duration_stream_time += item.duration
diarization_obj.insert_silence(item.duration) diarization_obj.insert_silence(item.duration)
continue continue
elif isinstance(item, np.ndarray): elif isinstance(item, np.ndarray):
@@ -317,22 +311,26 @@ class AudioProcessor:
# Process diarization # Process diarization
await diarization_obj.diarize(pcm_array) await diarization_obj.diarize(pcm_array)
segments = diarization_obj.get_segments()
async with self.lock: if self.diarization_before_transcription:
self.tokens, last_segment = diarization_obj.assign_speakers_to_tokens( if segments and segments[-1].speaker != self.current_speaker:
self.tokens, self.current_speaker = segments[-1].speaker
use_punctuation_split=self.args.punctuation_split cut_at = int(segments[-1].start*16000 - (self.cumulative_pcm_len))
) await self.transcription_queue.put(pcm_array[cut_at:])
if len(self.tokens) > 0: await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=cut_at))
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker) await self.transcription_queue.put(pcm_array[:cut_at])
if buffer_diarization: else:
self.buffer_diarization = buffer_diarization await self.transcription_queue.put(pcm_array)
else:
# if last_segment is not None and last_segment.speaker != self.last_detected_speaker: async with self.lock:
# if not self.speaker_languages.get(last_segment.speaker, None): self.tokens = diarization_obj.assign_speakers_to_tokens(
# self.last_detected_speaker = last_segment.speaker self.tokens,
# self.online.on_new_speaker(last_segment) use_punctuation_split=self.args.punctuation_split
)
self.cumulative_pcm_len += len(pcm_array)
if len(self.tokens) > 0:
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
self.diarization_queue.task_done() self.diarization_queue.task_done()
except Exception as e: except Exception as e:
@@ -342,7 +340,7 @@ class AudioProcessor:
self.diarization_queue.task_done() self.diarization_queue.task_done()
logger.info("Diarization processor task finished.") logger.info("Diarization processor task finished.")
async def translation_processor(self, online_translation): async def translation_processor(self):
# the idea is to ignore diarization for the moment. We use only transcription tokens. # the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation # And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex. # in the future we want to have different languages for each speaker etc, so it will be more complex.
@@ -354,7 +352,7 @@ class AudioProcessor:
self.translation_queue.task_done() self.translation_queue.task_done()
break break
elif type(item) is Silence: elif type(item) is Silence:
online_translation.insert_silence(item.duration) self.online_translation.insert_silence(item.duration)
continue continue
# get all the available tokens for translation. The more words, the more precise # get all the available tokens for translation. The more words, the more precise
@@ -368,9 +366,8 @@ class AudioProcessor:
break break
tokens_to_process.append(additional_token) tokens_to_process.append(additional_token)
if tokens_to_process: if tokens_to_process:
online_translation.insert_tokens(tokens_to_process) self.online_translation.insert_tokens(tokens_to_process)
self.translated_segments = await asyncio.to_thread(online_translation.process) self.translated_segments = await asyncio.to_thread(self.online_translation.process)
self.translation_queue.task_done() self.translation_queue.task_done()
for _ in additional_tokens: for _ in additional_tokens:
self.translation_queue.task_done() self.translation_queue.task_done()
@@ -423,23 +420,15 @@ class AudioProcessor:
) )
if end_w_silence: if end_w_silence:
buffer_transcription = Transcript() buffer_transcription = Transcript()
buffer_diarization = Transcript()
else: else:
buffer_transcription = state.buffer_transcription buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
# Handle undiarized text buffer_diarization = ''
if undiarized_text: if undiarized_text:
combined = self.sep.join(undiarized_text) buffer_diarization = self.sep.join(undiarized_text)
if buffer_transcription:
combined += self.sep
async with self.lock: async with self.lock:
self.end_attributed_speaker = state.end_attributed_speaker self.end_attributed_speaker = state.end_attributed_speaker
if buffer_diarization:
self.buffer_diarization = buffer_diarization
buffer_diarization.text = combined
response_status = "active_transcription" response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization: if not state.tokens and not buffer_transcription and not buffer_diarization:
@@ -455,8 +444,8 @@ class AudioProcessor:
response = FrontData( response = FrontData(
status=response_status, status=response_status,
lines=lines, lines=lines,
buffer_transcription=buffer_transcription.text, buffer_transcription=buffer_transcription.text.strip(),
buffer_diarization=buffer_transcription.text, buffer_diarization=buffer_diarization.strip(),
remaining_time_transcription=state.remaining_time_transcription, remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
) )
@@ -515,8 +504,8 @@ class AudioProcessor:
self.all_tasks_for_cleanup.append(self.diarization_task) self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task)
if self.args.target_language and self.args.lan != 'auto': if self.online_translation:
self.translation_task = asyncio.create_task(self.translation_processor(self.online_translation)) self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task) self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task) processing_tasks_for_watchdog.append(self.translation_task)
@@ -638,7 +627,7 @@ class AudioProcessor:
silence_buffer = Silence(duration=time() - self.start_silence) silence_buffer = Silence(duration=time() - self.start_silence)
if silence_buffer: if silence_buffer:
if self.args.transcription and self.transcription_queue: if not self.diarization_before_transcription and self.transcription_queue:
await self.transcription_queue.put(silence_buffer) await self.transcription_queue.put(silence_buffer)
if self.args.diarization and self.diarization_queue: if self.args.diarization and self.diarization_queue:
await self.diarization_queue.put(silence_buffer) await self.diarization_queue.put(silence_buffer)
@@ -646,7 +635,7 @@ class AudioProcessor:
await self.translation_queue.put(silence_buffer) await self.translation_queue.put(silence_buffer)
if not self.silence: if not self.silence:
if self.args.transcription and self.transcription_queue: if not self.diarization_before_transcription and self.transcription_queue:
await self.transcription_queue.put(pcm_array.copy()) await self.transcription_queue.put(pcm_array.copy())
if self.args.diarization and self.diarization_queue: if self.args.diarization and self.diarization_queue:

View File

@@ -5,9 +5,6 @@ from fastapi.middleware.cors import CORSMiddleware
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
import asyncio import asyncio
import logging import logging
from starlette.staticfiles import StaticFiles
import pathlib
import whisperlivekit.web as webpkg
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)
@@ -33,8 +30,6 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
web_dir = pathlib.Path(webpkg.__file__).parent
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
@app.get("/") @app.get("/")
async def get(): async def get():

View File

@@ -145,8 +145,8 @@ class TranscriptionEngine:
self.translation_model = None self.translation_model = None
if self.args.target_language: if self.args.target_language:
if self.args.lan == 'auto': if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
raise Exception('Translation cannot be set with language auto') raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
else: else:
from whisperlivekit.translation.translation import load_model from whisperlivekit.translation.translation import load_model
self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers

View File

@@ -242,7 +242,7 @@ class DiartDiarization:
token.speaker = extract_number(segment.speaker) + 1 token.speaker = extract_number(segment.speaker) + 1
else: else:
tokens = add_speaker_to_tokens(segments, tokens) tokens = add_speaker_to_tokens(segments, tokens)
return tokens, segments[-1] return tokens
def concatenate_speakers(segments): def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}] segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]

View File

@@ -296,7 +296,7 @@ class SortformerDiarizationOnline:
if not segments or not tokens: if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment") logger.debug("No segments or tokens available for speaker assignment")
return tokens, None return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments") logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
use_punctuation_split = False use_punctuation_split = False
@@ -313,7 +313,7 @@ class SortformerDiarizationOnline:
# Use punctuation-aware assignment (similar to diart_backend) # Use punctuation-aware assignment (similar to diart_backend)
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens) tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens, segments[-1] return tokens
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list: def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
""" """

View File

@@ -38,12 +38,16 @@ def new_line(
text = token.text + debug_info, text = token.text + debug_info,
start = token.start, start = token.start,
end = token.end, end = token.end,
detected_language=token.detected_language
) )
def append_token_to_last_line(lines, sep, token, debug_info): def append_token_to_last_line(lines, sep, token, debug_info):
if token.text: if token.text:
lines[-1].text += sep + token.text + debug_info lines[-1].text += sep + token.text + debug_info
lines[-1].end = token.end lines[-1].end = token.end
if not lines[-1].detected_language and token.detected_language:
lines[-1].detected_language = token.detected_language
def format_output(state, silence, current_time, args, debug, sep): def format_output(state, silence, current_time, args, debug, sep):
diarization = args.diarization diarization = args.diarization

View File

@@ -4,7 +4,7 @@ import logging
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
import logging import logging
import platform import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
from whisperlivekit.warmup import load_file from whisperlivekit.warmup import load_file
from .whisper import load_model, tokenizer from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND from .whisper.audio import TOKENS_PER_SECOND
@@ -93,14 +93,16 @@ class SimulStreamingOnlineProcessor:
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
self.model.insert_audio(audio_tensor) self.model.insert_audio(audio_tensor)
def on_new_speaker(self, last_segment: SpeakerSegment): def new_speaker(self, change_speaker: ChangeSpeaker):
self.model.on_new_speaker(last_segment) self.process_iter(is_last=True)
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
self.model.speaker = change_speaker.speaker
self.global_time_offset = change_speaker.start
def get_buffer(self): def get_buffer(self):
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
return concat_buffer return concat_buffer
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
""" """
Process accumulated audio chunks using SimulStreaming. Process accumulated audio chunks using SimulStreaming.
@@ -108,9 +110,13 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
""" """
try: try:
timestamped_words, timestamped_buffer_language = self.model.infer(is_last=is_last) timestamped_words = self.model.infer(is_last=is_last)
self.buffer = timestamped_buffer_language if timestamped_words and timestamped_words[0].detected_language == None:
self.buffer.extend(timestamped_words)
return [], self.end
self.committed.extend(timestamped_words) self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end return timestamped_words, self.end

View File

@@ -66,7 +66,7 @@ class PaddedAlignAttWhisper:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
self.speaker = -1
self.decode_options = DecodingOptions( self.decode_options = DecodingOptions(
language = cfg.language, language = cfg.language,
without_timestamps = True, without_timestamps = True,
@@ -78,7 +78,6 @@ class PaddedAlignAttWhisper:
self.detected_language = cfg.language if cfg.language != "auto" else None self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0 self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False self.reset_tokenizer_to_auto_next_call = False
self.sentence_start_time = 0.0
self.max_text_len = self.model.dims.n_text_ctx self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks) self.num_decoder_layers = len(self.model.decoder.blocks)
@@ -153,7 +152,7 @@ class PaddedAlignAttWhisper:
self.last_attend_frame = -self.cfg.rewind_threshold self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0 self.cumulative_time_offset = 0.0
self.sentence_start_time = self.cumulative_time_offset + self.segments_len() self.first_timestamp = None
if self.cfg.max_context_tokens is None: if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len self.max_context_tokens = self.max_text_len
@@ -261,7 +260,6 @@ class PaddedAlignAttWhisper:
self.init_context() self.init_context()
logger.debug(f"Context: {self.context}") logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2: if not complete and len(self.segments) > 2:
logger.debug("keeping last two segments because they are and it is not complete.")
self.segments = self.segments[-2:] self.segments = self.segments[-2:]
else: else:
logger.debug("removing all segments.") logger.debug("removing all segments.")
@@ -434,18 +432,19 @@ class PaddedAlignAttWhisper:
end_encode = time() end_encode = time()
# print('Encoder duration:', end_encode-beg_encode) # print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None: if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time seconds_since_start = self.segments_len() - self.first_timestamp
if seconds_since_start >= 3.0: if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature) language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}") print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan) self.create_tokenizer(top_lan)
self.refresh_segment(complete=True) self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.detected_language = top_lan self.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
else:
logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
self.trim_context() self.trim_context()
current_tokens = self._current_tokens() current_tokens = self._current_tokens()
@@ -590,6 +589,10 @@ class PaddedAlignAttWhisper:
self._clean_cache() self._clean_cache()
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
self.first_timestamp = l_absolute_timestamps[0]
timestamped_words = [] timestamped_words = []
timestamp_idx = 0 timestamp_idx = 0
for word, word_tokens in zip(split_words, split_tokens): for word, word_tokens in zip(split_words, split_tokens):
@@ -604,15 +607,11 @@ class PaddedAlignAttWhisper:
end=current_timestamp + 0.1, end=current_timestamp + 0.1,
text= word, text= word,
probability=0.95, probability=0.95,
language=self.detected_language speaker=self.speaker,
detected_language=self.detected_language
).with_offset( ).with_offset(
self.global_time_offset self.global_time_offset
) )
timestamped_words.append(timestamp_entry) timestamped_words.append(timestamp_entry)
if self.detected_language is None and self.cfg.language == "auto": return timestamped_words
timestamped_buffer_language, timestamped_words = timestamped_words, []
else:
timestamped_buffer_language = []
return timestamped_words, timestamped_buffer_language

View File

@@ -17,7 +17,7 @@ class TimedText:
speaker: Optional[int] = -1 speaker: Optional[int] = -1
probability: Optional[float] = None probability: Optional[float] = None
is_dummy: Optional[bool] = False is_dummy: Optional[bool] = False
language: str = None detected_language: Optional[str] = None
def is_punctuation(self): def is_punctuation(self):
return self.text.strip() in PUNCTUATION_MARKS return self.text.strip() in PUNCTUATION_MARKS
@@ -41,11 +41,11 @@ class TimedText:
return bool(self.text) return bool(self.text)
@dataclass @dataclass()
class ASRToken(TimedText): class ASRToken(TimedText):
def with_offset(self, offset: float) -> "ASRToken": def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added.""" """Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability) return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language)
@dataclass @dataclass
class Sentence(TimedText): class Sentence(TimedText):
@@ -123,7 +123,6 @@ class Silence():
@dataclass @dataclass
class Line(TimedText): class Line(TimedText):
translation: str = '' translation: str = ''
detected_language: str = None
def to_dict(self): def to_dict(self):
_dict = { _dict = {
@@ -161,13 +160,17 @@ class FrontData():
if self.error: if self.error:
_dict['error'] = self.error _dict['error'] = self.error
return _dict return _dict
@dataclass
class ChangeSpeaker:
speaker: int
start: int
@dataclass @dataclass
class State(): class State():
tokens: list tokens: list
translated_segments: list translated_segments: list
buffer_transcription: str buffer_transcription: str
buffer_diarization: str
end_buffer: float end_buffer: float
end_attributed_speaker: float end_attributed_speaker: float
remaining_time_transcription: float remaining_time_transcription: float

View File

@@ -3,7 +3,7 @@ import time
import ctranslate2 import ctranslate2
import torch import torch
import transformers import transformers
from dataclasses import dataclass from dataclasses import dataclass, field
import huggingface_hub import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code from whisperlivekit.translation.mapping_languages import get_nllb_code
from whisperlivekit.timed_objects import Translation from whisperlivekit.timed_objects import Translation
@@ -18,9 +18,20 @@ MIN_SILENCE_DURATION_DEL_BUFFER = 3 #After a silence of x seconds, we consider t
@dataclass @dataclass
class TranslationModel(): class TranslationModel():
translator: ctranslate2.Translator translator: ctranslate2.Translator
tokenizer: dict
device: str device: str
tokenizer: dict = field(default_factory=dict)
backend_type: str = 'ctranslate2' backend_type: str = 'ctranslate2'
model_size: str = '600M'
def get_tokenizer(self, input_lang):
if not self.tokenizer.get(input_lang, False):
self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained(
f"facebook/nllb-200-distilled-{self.model_size}",
src_lang=input_lang,
clean_up_tokenization_spaces=True
)
return self.tokenizer[input_lang]
def load_model(src_langs, backend='ctranslate2', model_size='600M'): def load_model(src_langs, backend='ctranslate2', model_size='600M'):
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -33,14 +44,20 @@ def load_model(src_langs, backend='ctranslate2', model_size='600M'):
translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}") translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}")
tokenizer = dict() tokenizer = dict()
for src_lang in src_langs: for src_lang in src_langs:
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True) if src_lang != 'auto':
tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(MODEL, src_lang=src_lang, clean_up_tokenization_spaces=True)
return TranslationModel( translation_model = TranslationModel(
translator=translator, translator=translator,
tokenizer=tokenizer, tokenizer=tokenizer,
backend_type=backend, backend_type=backend,
device = device device = device,
model_size = model_size
) )
for src_lang in src_langs:
if src_lang != 'auto':
translation_model.get_tokenizer(src_lang)
return translation_model
class OnlineTranslation: class OnlineTranslation:
def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list): def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
@@ -63,16 +80,12 @@ class OnlineTranslation:
self.commited.extend(self.buffer[:i]) self.commited.extend(self.buffer[:i])
self.buffer = results[i:] self.buffer = results[i:]
def translate(self, input, input_lang=None, output_lang=None): def translate(self, input, input_lang, output_lang):
if not input: if not input:
return "" return ""
if input_lang is None:
input_lang = self.input_languages[0]
if output_lang is None:
output_lang = self.output_languages[0]
nllb_output_lang = get_nllb_code(output_lang) nllb_output_lang = get_nllb_code(output_lang)
tokenizer = self.translation_model.tokenizer[input_lang] tokenizer = self.translation_model.get_tokenizer(input_lang)
tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device) tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)
if self.translation_model.backend_type == 'ctranslate2': if self.translation_model.backend_type == 'ctranslate2':
@@ -90,7 +103,15 @@ class OnlineTranslation:
text = ' '.join([token.text for token in tokens]) text = ' '.join([token.text for token in tokens])
start = tokens[0].start start = tokens[0].start
end = tokens[-1].end end = tokens[-1].end
translated_text = self.translate(text) if self.input_languages[0] == 'auto':
input_lang = tokens[0].detected_language
else:
input_lang = self.input_languages[0]
translated_text = self.translate(text,
input_lang,
self.output_languages[0]
)
translation = Translation( translation = Translation(
text=translated_text, text=translated_text,
start=start, start=start,

View File

@@ -534,22 +534,6 @@ label {
color: var(--muted); color: var(--muted);
} }
.label_language img {
width: 12px;
height: 12px;
}
.silence-icon {
width: 14px;
height: 14px;
vertical-align: text-bottom;
}
.speaker-icon {
width: 16px;
height: 16px;
vertical-align: text-bottom;
}
.speaker-badge { .speaker-badge {
display: inline-flex; display: inline-flex;

View File

@@ -40,6 +40,11 @@ const timerElement = document.querySelector(".timer");
const themeRadios = document.querySelectorAll('input[name="theme"]'); const themeRadios = document.querySelectorAll('input[name="theme"]');
const microphoneSelect = document.getElementById("microphoneSelect"); const microphoneSelect = document.getElementById("microphoneSelect");
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
function getWaveStroke() { function getWaveStroke() {
const styles = getComputedStyle(document.documentElement); const styles = getComputedStyle(document.documentElement);
const v = styles.getPropertyValue("--wave-stroke").trim(); const v = styles.getPropertyValue("--wave-stroke").trim();
@@ -335,19 +340,17 @@ function renderLinesWithBuffer(
let speakerLabel = ""; let speakerLabel = "";
if (item.speaker === -2) { if (item.speaker === -2) {
const silenceIcon = `<img class="silence-icon" src="/web/src/silence.svg" alt="Silence" />`;
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == 0 && !isFinalizing) { } else if (item.speaker == 0 && !isFinalizing) {
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1( speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
remaining_time_diarization remaining_time_diarization
)}</span> second(s) of audio are undergoing diarization</span></span>`; )}</span> second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker !== 0) { } else if (item.speaker !== 0) {
const speakerIcon = `<img class="speaker-icon" src="/web/src/speaker.svg" alt="Speaker ${item.speaker}" />`;
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`; const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`; speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
if (item.detected_language) { if (item.detected_language) {
speakerLabel += `<span class="label_language"><img src="/web/src/language.svg" alt="Detected language" width="12" height="12" /><span>${item.detected_language}</span></span>`; speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
} }
} }
@@ -388,7 +391,7 @@ function renderLinesWithBuffer(
if (item.translation) { if (item.translation) {
currentLineText += `<div class="label_translation"> currentLineText += `<div class="label_translation">
<img src="/web/src/translate.svg" alt="Translation" width="12" height="12" /> ${translationIcon}
<span>${item.translation}</span> <span>${item.translation}</span>
</div>`; </div>`;
} }