From 8a5e2adb1e9971c9717aba14c9a7f472efaf9f3b Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Thu, 31 Jul 2025 16:25:34 +0200 Subject: [PATCH] simulstreaming: fixes token handling during warm-up phase --- whisperlivekit/simul_whisper/simul_whisper.py | 7 ++++--- whisperlivekit/whisper_streaming_custom/backends.py | 1 - whisperlivekit/whisper_streaming_custom/online_asr.py | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 226c070..250798f 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -261,9 +261,10 @@ class PaddedAlignAttWhisper: segments_len -= removed_len self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len) self.segments = self.segments[1:] - logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}") - self.context.append_token_ids(self.tokens[1][0,:]) - self.tokens = [self.initial_tokens] + self.tokens[2:] + if len(self.tokens) > 1: # When warming up, we can have a too long segments_len while not having any tokens yet + self.context.append_token_ids(self.tokens[1][0,:]) + self.context.append_token_ids(self.tokens[1][0,:]) + self.tokens = [self.initial_tokens] + self.tokens[2:] return removed_len diff --git a/whisperlivekit/whisper_streaming_custom/backends.py b/whisperlivekit/whisper_streaming_custom/backends.py index eea017d..d6ad639 100644 --- a/whisperlivekit/whisper_streaming_custom/backends.py +++ b/whisperlivekit/whisper_streaming_custom/backends.py @@ -481,7 +481,6 @@ class SimulStreamingASR(ASRBase): try: if isinstance(audio, np.ndarray): audio = torch.from_numpy(audio).float() - print(audio) self.model.insert_audio(audio) self.model.infer(True) self.model.refresh_segment(complete=True) diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py index f5e3116..7f2c65c 100644 --- a/whisperlivekit/whisper_streaming_custom/online_asr.py +++ b/whisperlivekit/whisper_streaming_custom/online_asr.py @@ -680,8 +680,7 @@ class SimulStreamingOnlineProcessor: except Exception as e: - logger.error(f"SimulStreaming processing error: {e}") - logger.error(f"Error details: {type(e).__name__}: {str(e)}") + logger.exception(f"SimulStreaming processing error: {e}") return [], self.end def finish(self) -> Tuple[List[ASRToken], float]: