From 55e08474f3aa8a5077c272e733cf22b3db31191e Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sat, 16 Aug 2025 23:06:16 +0200 Subject: [PATCH] recycle backend in simulstreaming thanks to new remove hooks function --- whisperlivekit/simul_whisper/backend.py | 74 ++++++++++++++----- whisperlivekit/simul_whisper/simul_whisper.py | 20 +++-- 2 files changed, 70 insertions(+), 24 deletions(-) diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 90fb900..fa550da 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -7,6 +7,8 @@ from whisperlivekit.timed_objects import ASRToken, Transcript from whisperlivekit.warmup import load_file from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE from .whisper import load_model, tokenizer +from .whisper.audio import TOKENS_PER_SECOND + import os import gc logger = logging.getLogger(__name__) @@ -21,6 +23,8 @@ except ImportError as e: """SimulStreaming dependencies are not available. Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""") +# TOO_MANY_REPETITIONS = 3 + class SimulStreamingOnlineProcessor: SAMPLING_RATE = 16000 @@ -33,33 +37,41 @@ class SimulStreamingOnlineProcessor: self.asr = asr self.logfile = logfile self.is_last = False - self.beg = 0.0 self.end = 0.0 self.cumulative_audio_duration = 0.0 self.committed: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = [] - model = asr.get_new_model_instance() - self.model = PaddedAlignAttWhisper( - cfg=asr.cfg, - loaded_model=model) + self.load_new_backend() if asr.tokenizer: self.model.tokenizer = asr.tokenizer - def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None): + def load_new_backend(self): + model = self.asr.get_new_model_instance() + self.model = PaddedAlignAttWhisper( + cfg=self.asr.cfg, + loaded_model=model) + + def insert_silence(self, silence_duration): + """ + If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame + """ + if silence_duration < 5: + gap_silence = torch.zeros(int(16000*min(silence_duration, 1.0))) + self.model.insert_audio(gap_silence) + self.model.last_attend_frame += int(TOKENS_PER_SECOND * (min(silence_duration, 1.0) - 1.0)) + else: + self.model.refresh_segment(complete=True) + self.model.last_attend_frame += int(TOKENS_PER_SECOND * silence_duration) + + + + def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time): """Append an audio chunk to be processed by SimulStreaming.""" # Convert numpy array to torch tensor audio_tensor = torch.from_numpy(audio).float() - - # Update timing - chunk_duration = len(audio) / self.SAMPLING_RATE - self.cumulative_audio_duration += chunk_duration - - if audio_stream_end_time is not None: - self.end = audio_stream_end_time - else: - self.end = self.cumulative_audio_duration + self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. self.model.insert_audio(audio_tensor) def get_buffer(self): @@ -81,7 +93,7 @@ class SimulStreamingOnlineProcessor: frames = [p["most_attended_frames"][0] for p in pr] tokens = tokens.copy() ret = [] - for sw,st in zip(split_words,split_tokens): + for sw,st in zip(split_words, split_tokens): b = None for stt in st: t,f = tokens.pop(0), frames.pop(0) @@ -116,7 +128,29 @@ class SimulStreamingOnlineProcessor: probability=0.95 # fake prob. Maybe we can extract it from the model? ) new_tokens.append(token) - self.committed.extend(new_tokens) + + # identical_tokens = 0 + # n_new_tokens = len(new_tokens) + # if n_new_tokens: + + self.committed.extend(new_tokens) + + # if token in self.committed: + # pos = len(self.committed) - 1 - self.committed[::-1].index(token) + # if pos: + # for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens): + # commited_segment = self.committed[i:i+n_new_tokens] + # if commited_segment == new_tokens: + # identical_segments +=1 + # if identical_tokens >= TOO_MANY_REPETITIONS: + # logger.warning('Too many repetition, model is stuck. Load a new one') + # self.committed = self.committed[:i] + # self.load_new_backend() + # return [], self.end + + # pos = self.committed.rindex(token) + + return new_tokens, self.end @@ -137,10 +171,11 @@ class SimulStreamingOnlineProcessor: def __del__(self): # free the model and add a new model to stack. - del self.model + # del self.model gc.collect() torch.cuda.empty_cache() - self.asr.new_model_to_stack() + # self.asr.new_model_to_stack() + self.model.remove_hooks() class SimulStreamingASR(): """SimulStreaming backend with AlignAtt policy.""" @@ -231,6 +266,7 @@ class SimulStreamingASR(): self.models.append(self.load_model()) new_model = self.models.pop() return new_model + # self.models[0] def new_model_to_stack(self): self.models.append(self.load_model()) diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index a8188d9..aa3d794 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -56,6 +56,7 @@ class PaddedAlignAttWhisper: self.max_text_len = self.model.dims.n_text_ctx self.num_decoder_layers = len(self.model.decoder.blocks) self.cfg = cfg + self.l_hooks = [] # model to detect end-of-word boundary at the end of the segment self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg, @@ -69,7 +70,8 @@ class PaddedAlignAttWhisper: t = F.softmax(net_output[1], dim=-1) self.dec_attns.append(t.squeeze(0)) for b in self.model.decoder.blocks: - b.cross_attn.register_forward_hook(layer_hook) + hook = b.cross_attn.register_forward_hook(layer_hook) + self.l_hooks.append(hook) self.kv_cache = {} def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor): @@ -82,10 +84,13 @@ class PaddedAlignAttWhisper: return self.kv_cache[module.cache_id] for i,b in enumerate(self.model.decoder.blocks): - b.attn.key.register_forward_hook(kv_hook) - b.attn.value.register_forward_hook(kv_hook) - b.cross_attn.key.register_forward_hook(kv_hook) - b.cross_attn.value.register_forward_hook(kv_hook) + hooks = [ + b.attn.key.register_forward_hook(kv_hook), + b.attn.value.register_forward_hook(kv_hook), + b.cross_attn.key.register_forward_hook(kv_hook), + b.cross_attn.value.register_forward_hook(kv_hook), + ] + self.l_hooks.extend(hooks) self.align_source = {} self.num_align_heads = 0 @@ -139,6 +144,11 @@ class PaddedAlignAttWhisper: self.inference.kv_cache = self.kv_cache self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) + + def remove_hooks(self): + print('remove hook') + for hook in self.l_hooks: + hook.remove() def create_tokenizer(self, language=None): self.tokenizer = tokenizer.get_tokenizer(