diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index ed87dab..2ffbc0f 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -429,7 +429,7 @@ class AudioProcessor:
         state = await self.get_current_state()
 
         # Format output
-        lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
+        lines, undiarized_text, end_w_silence = format_output(
             state,
             self.silence,
             current_time = time() - self.beg_loop if self.beg_loop else None,
@@ -437,6 +437,13 @@ class AudioProcessor:
             debug = self.debug,
             sep=self.sep
         )
+        if end_w_silence:
+            buffer_transcription = ''
+            buffer_diarization = ''
+        else:
+            buffer_transcription = state.buffer_transcription
+            buffer_diarization = state.buffer_diarization
+
         # Handle undiarized text
         if undiarized_text:
             combined = self.sep.join(undiarized_text)
diff --git a/whisperlivekit/remove_silences.py b/whisperlivekit/remove_silences.py
index 3e4edb1..cdbb442 100644
--- a/whisperlivekit/remove_silences.py
+++ b/whisperlivekit/remove_silences.py
@@ -77,15 +77,17 @@ def no_token_to_silence(tokens):
             new_tokens.append(token)
     return new_tokens
 
-def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
+def ends_with_silence(tokens, current_time, vac_detected_silence):
+    end_w_silence = False
     if not tokens:
-        return [], buffer_transcription, buffer_diarization
+        return [], end_w_silence
     last_token = tokens[-1]
    if tokens and current_time and (
         current_time - last_token.end >= END_SILENCE_DURATION
-        or
+        or
         (current_time - last_token.end >= 3 and vac_detected_silence)
     ):
+        end_w_silence = True
         if last_token.speaker == -2:
             last_token.end = current_time
         else:
@@ -97,14 +99,12 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
                 probability=0.95
             )
         )
-    buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
-    buffer_diarization = ""
-    return tokens, buffer_transcription, buffer_diarization
+    return tokens, end_w_silence
 
-def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
+def handle_silences(tokens, current_time, vac_detected_silence):
     tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
     tokens = no_token_to_silence(tokens)
-    tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
-    return tokens, buffer_transcription, buffer_diarization
+    tokens, end_w_silence = ends_with_silence(tokens, current_time, vac_detected_silence)
+    return tokens, end_w_silence
\ No newline at end of file
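The silence handling above now reports trailing silence through a boolean instead of clearing the transcription and diarization buffers itself; the caller in `audio_processor.py` decides what to do with them. A standalone sketch of that control flow, where `Token` and the 5-second threshold are simplified stand-ins for the package's `ASRToken` and `END_SILENCE_DURATION`:

```python
from dataclasses import dataclass

END_SILENCE_DURATION = 5.0  # stand-in value; the real constant lives in remove_silences.py

@dataclass
class Token:  # simplified stand-in for ASRToken
    end: float
    speaker: int = 1

def ends_with_silence(tokens, current_time, vac_detected_silence):
    # Report the silence; the caller owns the buffers now.
    end_w_silence = False
    if not tokens:
        return [], end_w_silence
    last = tokens[-1]
    if current_time and (
        current_time - last.end >= END_SILENCE_DURATION
        or (current_time - last.end >= 3 and vac_detected_silence)
    ):
        end_w_silence = True
    return tokens, end_w_silence

tokens, silent = ends_with_silence([Token(end=10.0)], current_time=14.0, vac_detected_silence=True)
assert silent  # 4 s of trailing silence + VAC detection: the caller clears its buffers
```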
diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py
index 577a1c5..8f30a9a 100644
--- a/whisperlivekit/results_formater.py
+++ b/whisperlivekit/results_formater.py
@@ -50,14 +50,12 @@ def format_output(state, silence, current_time, args, debug, sep):
     disable_punctuation_split = args.disable_punctuation_split
     tokens = state.tokens
     translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
-    buffer_transcription = state.buffer_transcription
-    buffer_diarization = state.buffer_diarization
     end_attributed_speaker = state.end_attributed_speaker
     previous_speaker = -1
     lines = []
     undiarized_text = []
-    tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
+    tokens, end_w_silence = handle_silences(tokens, current_time, silence)
     last_punctuation = None
     for i, token in enumerate(tokens):
         speaker = token.speaker
@@ -121,6 +119,7 @@ def format_output(state, silence, current_time, args, debug, sep):
                 pass
 
         append_token_to_last_line(lines, sep, token, debug_info)
+
     if lines and translated_segments:
         unassigned_translated_segments = []
         for ts in translated_segments:
@@ -151,4 +150,4 @@ def format_output(state, silence, current_time, args, debug, sep):
             else:
                 remaining_segments.append(ts)
         unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
-    return lines, undiarized_text, buffer_transcription, ''
+    return lines, undiarized_text, end_w_silence
diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index e1e12e2..ab648db 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -6,7 +6,6 @@ import logging
 import platform
 from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment
 from whisperlivekit.warmup import load_file
-from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
 from .whisper import load_model, tokenizer
 from .whisper.audio import TOKENS_PER_SECOND
 import os
@@ -23,7 +22,11 @@ try:
     HAS_MLX_WHISPER = True
 except ImportError:
     if platform.system() == "Darwin" and platform.machine() == "arm64":
-        print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
+        print(f"""
+        {"="*50}
+        MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
+        {"="*50}
+        """)
     HAS_MLX_WHISPER = False
 if HAS_MLX_WHISPER:
     HAS_FASTER_WHISPER = False
@@ -49,8 +52,12 @@ class SimulStreamingOnlineProcessor:
         self.asr = asr
         self.logfile = logfile
         self.end = 0.0
-        self.global_time_offset = 0.0
-
+        self.buffer = Transcript(
+            start=None,
+            end=None,
+            text='',
+            probability=None
+        )
         self.committed: List[ASRToken] = []
         self.last_result_tokens: List[ASRToken] = []
         self.load_new_backend()
@@ -79,7 +86,7 @@ class SimulStreamingOnlineProcessor:
         else:
             self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
             self.model.refresh_segment(complete=True)
-        self.global_time_offset = silence_duration + offset
+        self.model.global_time_offset = silence_duration + offset
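With `infer()` about to return finished `ASRToken`s (see the hunks below), the time offset has to live on the model rather than on the online processor, because the model is now where tokens receive their absolute timestamps. A hypothetical minimal sketch of that ownership change; `TinyModel` and its internals are invented for illustration, only `global_time_offset` mirrors the patch:

```python
class TinyModel:
    def __init__(self):
        self.global_time_offset = 0.0  # advanced when silence is skipped

    def infer(self):
        segment_relative_start = 1.25  # dummy in-segment timestamp
        # The model applies the offset itself, so callers get absolute times.
        return [segment_relative_start + self.global_time_offset]

model = TinyModel()
model.global_time_offset = 30.0  # e.g. silence_duration + offset, as in the hunk above
assert model.infer() == [31.25]  # stream-absolute timestamp
```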
""" try: - split_words, split_tokens, l_absolute_timestamps = self.model.infer(is_last=is_last) - new_tokens = self.timestamped_text(split_words, split_tokens, l_absolute_timestamps) - + new_tokens = self.model.infer(is_last=is_last) self.committed.extend(new_tokens) return new_tokens, self.end @@ -163,7 +144,6 @@ class SimulStreamingASR(): sep = "" def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs): - logger.warning(SIMULSTREAMING_LICENSE) self.logfile = logfile self.transcribe_kargs = {} self.original_language = lan diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 49468e6..b23a57c 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from .whisper import load_model, DecodingOptions, tokenizer from .config import AlignAttConfig +from whisperlivekit.timed_objects import ASRToken from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES from .whisper.timing import median_filter from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language @@ -18,6 +19,7 @@ from time import time from .token_buffer import TokenBuffer import numpy as np +from ..timed_objects import PUNCTUATION_MARKS from .generation_progress import * DEC_PAD = 50257 @@ -40,12 +42,6 @@ else: except ImportError: HAS_FASTER_WHISPER = False -# New features added to the original version of Simul-Whisper: -# - large-v3 model support -# - translation support -# - beam search -# - prompt -- static vs. non-static -# - context class PaddedAlignAttWhisper: def __init__( self, @@ -79,6 +75,9 @@ class PaddedAlignAttWhisper: self.tokenizer_is_multilingual = not model_name.endswith(".en") self.create_tokenizer(cfg.language if cfg.language != "auto" else None) self.detected_language = cfg.language if cfg.language != "auto" else None + self.global_time_offset = 0.0 + self.reset_tokenizer_to_auto_next_call = False + self.sentence_start_time = 0.0 self.max_text_len = self.model.dims.n_text_ctx self.num_decoder_layers = len(self.model.decoder.blocks) @@ -153,6 +152,7 @@ class PaddedAlignAttWhisper: self.last_attend_frame = -self.cfg.rewind_threshold self.cumulative_time_offset = 0.0 + self.sentence_start_time = self.cumulative_time_offset + self.segments_len() if self.cfg.max_context_tokens is None: self.max_context_tokens = self.max_text_len @@ -382,11 +382,11 @@ class PaddedAlignAttWhisper: new_segment = True if len(self.segments) == 0: logger.debug("No segments, nothing to do") - return [], [], [] + return [] if not self._apply_minseglen(): logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") input_segments = torch.cat(self.segments, dim=0) - return [], [], [] + return [] # input_segments is concatenation of audio, it's one array if len(self.segments) > 1: @@ -394,6 +394,13 @@ class PaddedAlignAttWhisper: else: input_segments = self.segments[0] + # if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call: + # logger.debug("Resetting tokenizer to auto for new sentence.") + # self.create_tokenizer(None) + # self.detected_language = None + # self.init_tokens() + # self.reset_tokenizer_to_auto_next_call = False + # NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder beg_encode = time() if self.mlx_encoder: @@ -426,17 +433,21 @@ class 
@@ -426,17 +433,21 @@ class PaddedAlignAttWhisper:
         end_encode = time()
         # print('Encoder duration:', end_encode-beg_encode)
 
-        if self.cfg.language == "auto" and self.detected_language is None:
-            language_tokens, language_probs = self.lang_id(encoder_feature)
-            logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
-            top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
-            logger.info(f"Detected language: {top_lan} with p={p:.4f}")
-            #self.tokenizer.language = top_lan
-            #self.tokenizer.__post_init__()
-            self.create_tokenizer(top_lan)
-            self.detected_language = top_lan
-            self.init_tokens()
-            logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
+        # if self.cfg.language == "auto" and self.detected_language is None:
+        #     seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
+        #     if seconds_since_start >= 3.0:
+        #         language_tokens, language_probs = self.lang_id(encoder_feature)
+        #         logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
+        #         top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
+        #         logger.info(f"Detected language: {top_lan} with p={p:.4f}")
+        #         #self.tokenizer.language = top_lan
+        #         #self.tokenizer.__post_init__()
+        #         self.create_tokenizer(top_lan)
+        #         self.detected_language = top_lan
+        #         self.init_tokens()
+        #         logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
+        #     else:
+        #         logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
 
         self.trim_context()
         current_tokens = self._current_tokens()
@@ -446,6 +457,7 @@ class PaddedAlignAttWhisper:
 
         sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
         completed = False
+        # punctuation_stop = False
 
         attn_of_alignment_heads = None
         most_attended_frame = None
@@ -467,9 +479,7 @@ class PaddedAlignAttWhisper:
             if new_segment and self.tokenizer.no_speech is not None:
                 probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
                 no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
-                # generation["no_speech_prob"] = no_speech_probs[0]
                 if no_speech_probs[0] > self.cfg.nonspeech_prob:
-                    # generation["no_speech"] = True
                     logger.info("no speech, stop")
                     break
@@ -485,6 +495,19 @@ class PaddedAlignAttWhisper:
             logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
             self.debug_print_tokens(current_tokens)
 
+            # # Early stop on sentence-ending punctuation when language is auto
+            # if not completed and self.cfg.language == "auto":
+            #     last_token_id = current_tokens[0, -1].item()
+            #     last_token_text = self.tokenizer.decode([last_token_id]).strip()
+            #     if last_token_text in PUNCTUATION_MARKS:
+            #         logger.debug(f"Punctuation boundary '{last_token_text}' hit; stopping early to allow language re-check.")
+            #         punctuation_stop = True
+            #         # Ensure next call starts with auto language (re-detect for new sentence)
+            #         self.reset_tokenizer_to_auto_next_call = True
+            #         self.detected_language = None
+            #         self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
+            #         break
+
             attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
             for i, attn_mat in enumerate(self.dec_attns):
                 layer_rank = int(i % len(self.model.decoder.blocks))
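The disabled early-stop block above hinges on a single test: does the last decoded token, stripped of whitespace, sit in `PUNCTUATION_MARKS`? A standalone sketch of that boundary check; the stand-in set is assumed, the real one is imported from `whisperlivekit.timed_objects`:

```python
PUNCTUATION_MARKS = {".", "!", "?"}  # stand-in; the real set may differ

def is_sentence_boundary(decoded_token_text: str) -> bool:
    # Whisper word pieces often carry a leading space, hence the strip().
    return decoded_token_text.strip() in PUNCTUATION_MARKS

assert is_sentence_boundary(" .")
assert not is_sentence_boundary(" the")
```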
@@ -560,7 +583,7 @@ class PaddedAlignAttWhisper:
 
         tokens_to_split = current_tokens[0, token_len_before_decoding:]
 
-        if fire_detected or is_last:
+        if fire_detected or is_last: #or punctuation_stop:
             new_hypothesis = tokens_to_split.flatten().tolist()
             split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
         else:
@@ -582,4 +605,21 @@ class PaddedAlignAttWhisper:
 
         self._clean_cache()
 
-        return split_words, split_tokens, l_absolute_timestamps
\ No newline at end of file
+        timestamped_words = []
+        timestamp_idx = 0
+        current_timestamp = 0.0  # fallback so the first word still gets a start time if the alignment list is empty
+        for word, word_tokens in zip(split_words, split_tokens):
+            if timestamp_idx < len(l_absolute_timestamps):
+                current_timestamp = l_absolute_timestamps[timestamp_idx]
+            timestamp_idx += len(word_tokens)
+
+            timestamp_entry = ASRToken(
+                start=current_timestamp,
+                end=current_timestamp + 0.1,
+                text=word,
+                probability=0.95
+            ).with_offset(
+                self.global_time_offset
+            )
+            timestamped_words.append(timestamp_entry)
+        return timestamped_words
\ No newline at end of file
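The relocated word-timestamp loop replaces the old `pop(0)` walk, which emitted one entry per sub-token; it now reads one timestamp per word, at the index of the word's first sub-token, then advances the index by the word's sub-token count. A standalone sketch with dummy alignment data:

```python
split_words = ["hello", "world"]
split_tokens = [[101], [205, 206]]          # "world" spans two sub-tokens
l_absolute_timestamps = [0.50, 0.82, 0.90]  # one timestamp per sub-token

timestamp_idx = 0
starts = []
for word, word_tokens in zip(split_words, split_tokens):
    starts.append(l_absolute_timestamps[timestamp_idx])  # first sub-token's time
    timestamp_idx += len(word_tokens)                    # skip past the whole word

assert starts == [0.50, 0.82]
```

In the patch, when the alignment list runs short the loop keeps the previous word's timestamp instead of raising.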