From 8c799fa4d1751ffc529315ed8dcb57cc6f9bc27b Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Wed, 11 Feb 2026 22:10:00 +0100
Subject: [PATCH] fix simulstreaming vram leak: cap cross-attn accumulation +
 token budget

fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during decoding loop,
  using up to ~5GB for repetition loops. now capped to rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
---
 whisperlivekit/simul_whisper/backend.py           |  3 ---
 whisperlivekit/simul_whisper/mlx/simul_whisper.py |  6 +++++-
 whisperlivekit/simul_whisper/simul_whisper.py     | 11 ++++++++---
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index fb00f8d..33b1b81 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
         self.logfile = logfile
         self.end = 0.0
         self.buffer = []
-        self.committed: List[ASRToken] = []
-        self.last_result_tokens: List[ASRToken] = []
         self.model = self._create_alignatt()
 
         if asr.tokenizer:
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
                 self.buffer.extend(timestamped_words)
                 return [], self.end
 
-            self.committed.extend(timestamped_words)
             self.buffer = []
             return timestamped_words, self.end
         except Exception as e:
diff --git a/whisperlivekit/simul_whisper/mlx/simul_whisper.py b/whisperlivekit/simul_whisper/mlx/simul_whisper.py
index 4e2ba14..50b327c 100644
--- a/whisperlivekit/simul_whisper/mlx/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/mlx/simul_whisper.py
@@ -532,7 +532,9 @@ class MLXAlignAtt:
         accumulated_cross_attns = []
 
         audio_duration_s = self.segments_len()
-        max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
+        # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
+        # is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
+        max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
         tokens_produced_this_chunk = 0
 
         while not completed and current_tokens.shape[1] < self.max_text_len:
@@ -558,6 +560,8 @@ class MLXAlignAtt:
             mx.eval(logits)
 
             accumulated_cross_attns.append(cross_qk)
+            if len(accumulated_cross_attns) > 16:
+                accumulated_cross_attns = accumulated_cross_attns[-16:]
 
             if new_segment and self.tokenizer.no_speech is not None:
                 probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py
index 174e806..0ea15a7 100644
--- a/whisperlivekit/simul_whisper/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/simul_whisper.py
@@ -390,7 +390,6 @@ class AlignAtt:
             return []
         if not self._apply_minseglen():
             logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
-            input_segments = torch.cat(self.state.segments, dim=0)
             return []
 
         # input_segments is concatenation of audio, it's one array
@@ -485,7 +484,9 @@ class AlignAtt:
         accumulated_cross_attns = []
 
         audio_duration_s = self.segments_len()
-        max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
+        # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
+        # is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
+        max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
         tokens_produced_this_chunk = 0
         while not completed and current_tokens.shape[1] < self.max_text_len:
             # bos is 3 tokens
@@ -506,8 +507,12 @@ class AlignAtt:
             result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
             logits, cross_attns = result
 
-            # Accumulate cross-attention from this forward pass
+            # Accumulate cross-attention from this forward pass (rolling window to
+            # bound VRAM — only the last entry matters for alignment, and the
+            # median_filter kernel is 7, so 16 entries is more than enough).
             accumulated_cross_attns.append(cross_attns)
+            if len(accumulated_cross_attns) > 16:
+                accumulated_cross_attns = accumulated_cross_attns[-16:]
 
             if new_segment and self.tokenizer.no_speech is not None:
                 probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)