mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
Fix SimulStreaming VRAM leak: cap cross-attn accumulation + token budget
Fixes #283, fixes #275. `accumulated_cross_attns` was growing unboundedly during the decoding loop, using up to ~5 GB on repetition loops; it is now capped to a rolling window of 16 entries. `max_tokens_per_chunk` was computed from TOKENS_PER_SECOND (the mel frame rate, 50) instead of the actual text-token rate (~15/s), allowing 10–40x too many decoding steps. Also removed an unused `torch.cat` on an early-return path and the dead `self.committed` / `last_result_tokens` lists (never read). The same fixes are applied to the MLX variant.
This commit is contained in:
@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.model = self._create_alignatt()
|
||||
|
||||
if asr.tokenizer:
|
||||
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
except Exception as e:
|
||||
|
||||
@@ -532,7 +532,9 @@ class MLXAlignAtt:
|
||||
accumulated_cross_attns = []
|
||||
|
||||
audio_duration_s = self.segments_len()
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
|
||||
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||
tokens_produced_this_chunk = 0
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len:
|
||||
@@ -558,6 +560,8 @@ class MLXAlignAtt:
|
||||
mx.eval(logits)
|
||||
|
||||
accumulated_cross_attns.append(cross_qk)
|
||||
if len(accumulated_cross_attns) > 16:
|
||||
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
|
||||
|
||||
@@ -390,7 +390,6 @@ class AlignAtt:
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.state.segments, dim=0)
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
@@ -485,7 +484,9 @@ class AlignAtt:
|
||||
accumulated_cross_attns = []
|
||||
|
||||
audio_duration_s = self.segments_len()
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
|
||||
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||
tokens_produced_this_chunk = 0
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
@@ -506,8 +507,12 @@ class AlignAtt:
|
||||
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
||||
logits, cross_attns = result
|
||||
|
||||
# Accumulate cross-attention from this forward pass
|
||||
# Accumulate cross-attention from this forward pass (rolling window to
|
||||
# bound VRAM — only the last entry matters for alignment, and the
|
||||
# median_filter kernel is 7, so 16 entries is more than enough).
|
||||
accumulated_cross_attns.append(cross_attns)
|
||||
if len(accumulated_cross_attns) > 16:
|
||||
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user