Recycle backend in SimulStreaming thanks to the new remove_hooks function

Quentin Fuxa
2025-08-16 23:06:16 +02:00
parent 28bdc52e1d
commit 55e08474f3
2 changed files with 70 additions and 24 deletions
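
For context, a minimal sketch of the PyTorch hook lifecycle this commit builds on (illustration only, not code from this repo): register_forward_hook() returns a RemovableHandle, and calling .remove() on each stored handle detaches the callback, so a module can be handed to a fresh processor instead of being thrown away.

import torch

captured = []

def attn_hook(module, inputs, output):
    # Runs after every forward pass while the hook is attached.
    captured.append(output.detach())

layer = torch.nn.Linear(4, 4)
handles = [layer.register_forward_hook(attn_hook)]  # keep the handles, as l_hooks does below

layer(torch.randn(1, 4))   # hook fires: captured has 1 entry
for h in handles:          # the remove_hooks() idea: detach everything
    h.remove()
layer(torch.randn(1, 4))   # hook is gone: captured still has 1 entry
assert len(captured) == 1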

View File

@@ -7,6 +7,8 @@ from whisperlivekit.timed_objects import ASRToken, Transcript
 from whisperlivekit.warmup import load_file
 from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
 from .whisper import load_model, tokenizer
+from .whisper.audio import TOKENS_PER_SECOND
 import os
 import gc

 logger = logging.getLogger(__name__)
@@ -21,6 +23,8 @@ except ImportError as e:
"""SimulStreaming dependencies are not available.
Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]".""")
# TOO_MANY_REPETITIONS = 3
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
@@ -33,33 +37,41 @@ class SimulStreamingOnlineProcessor:
         self.asr = asr
         self.logfile = logfile
         self.is_last = False
         self.beg = 0.0
         self.end = 0.0
-        self.cumulative_audio_duration = 0.0
         self.committed: List[ASRToken] = []
         self.last_result_tokens: List[ASRToken] = []
-        model = asr.get_new_model_instance()
-        self.model = PaddedAlignAttWhisper(
-            cfg=asr.cfg,
-            loaded_model=model)
+        self.load_new_backend()
         if asr.tokenizer:
             self.model.tokenizer = asr.tokenizer

-    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
+    def load_new_backend(self):
+        model = self.asr.get_new_model_instance()
+        self.model = PaddedAlignAttWhisper(
+            cfg=self.asr.cfg,
+            loaded_model=model)
+
+    def insert_silence(self, silence_duration):
+        """
+        If silences are > 5s, we do a complete context clear. Otherwise, we just
+        insert a small silence and shift the last_attend_frame.
+        """
+        if silence_duration < 5:
+            gap_silence = torch.zeros(int(16000 * min(silence_duration, 1.0)))
+            self.model.insert_audio(gap_silence)
+            self.model.last_attend_frame += int(TOKENS_PER_SECOND * (min(silence_duration, 1.0) - 1.0))
+        else:
+            self.model.refresh_segment(complete=True)
+            self.model.last_attend_frame += int(TOKENS_PER_SECOND * silence_duration)
+
+    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
         """Append an audio chunk to be processed by SimulStreaming."""
         # Convert numpy array to torch tensor
         audio_tensor = torch.from_numpy(audio).float()
-        # Update timing
-        chunk_duration = len(audio) / self.SAMPLING_RATE
-        self.cumulative_audio_duration += chunk_duration
-        if audio_stream_end_time is not None:
-            self.end = audio_stream_end_time
-        else:
-            self.end = self.cumulative_audio_duration
+        self.end = audio_stream_end_time  # Only to be aligned with what happens in the whisperstreaming backend.
         self.model.insert_audio(audio_tensor)

     def get_buffer(self):
@@ -81,7 +93,7 @@ class SimulStreamingOnlineProcessor:
frames = [p["most_attended_frames"][0] for p in pr]
tokens = tokens.copy()
ret = []
for sw,st in zip(split_words,split_tokens):
for sw,st in zip(split_words, split_tokens):
b = None
for stt in st:
t,f = tokens.pop(0), frames.pop(0)
@@ -116,7 +128,29 @@ class SimulStreamingOnlineProcessor:
                 probability=0.95  # fake prob. Maybe we can extract it from the model?
             )
             new_tokens.append(token)

-        self.committed.extend(new_tokens)
+        # identical_tokens = 0
+        # n_new_tokens = len(new_tokens)
+        # if n_new_tokens:
+        self.committed.extend(new_tokens)
+        # if token in self.committed:
+        #     pos = len(self.committed) - 1 - self.committed[::-1].index(token)
+        #     if pos:
+        #         for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
+        #             commited_segment = self.committed[i:i+n_new_tokens]
+        #             if commited_segment == new_tokens:
+        #                 identical_tokens += 1
+        #             if identical_tokens >= TOO_MANY_REPETITIONS:
+        #                 logger.warning('Too many repetitions, the model is stuck. Load a new one.')
+        #                 self.committed = self.committed[:i]
+        #                 self.load_new_backend()
+        #                 return [], self.end
+        # pos = self.committed.rindex(token)
+
         return new_tokens, self.end
@@ -137,10 +171,11 @@ class SimulStreamingOnlineProcessor:
     def __del__(self):
         # free the model and add a new model to stack.
-        del self.model
+        # del self.model
         gc.collect()
         torch.cuda.empty_cache()
-        self.asr.new_model_to_stack()
+        # self.asr.new_model_to_stack()
+        self.model.remove_hooks()

 class SimulStreamingASR():
     """SimulStreaming backend with AlignAtt policy."""
@@ -231,6 +266,7 @@ class SimulStreamingASR():
             self.models.append(self.load_model())
         new_model = self.models.pop()
         return new_model
+        # self.models[0]

     def new_model_to_stack(self):
         self.models.append(self.load_model())
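
Taken together with __del__ above, the lifecycle becomes: a processor pops a loaded model off the stack, and on teardown it only detaches its hooks instead of deleting the model and loading a replacement. A hypothetical sketch of that acquire/recycle pattern (names invented here, not repo code):

class ModelStack:
    """Hand out loaded models and take them back once their hooks are removed."""

    def __init__(self, load_fn):
        self.load_fn = load_fn   # e.g. a closure around load_model()
        self.models = []

    def acquire(self):
        # Load lazily only when the stack is empty, mirroring
        # get_new_model_instance() above.
        if not self.models:
            self.models.append(self.load_fn())
        return self.models.pop()

    def release(self, wrapper):
        # Assumes the wrapper exposes remove_hooks(), as PaddedAlignAttWhisper
        # does below; after that, the instance is safe to hand out again.
        wrapper.remove_hooks()
        self.models.append(wrapper)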

View File

@@ -56,6 +56,7 @@ class PaddedAlignAttWhisper:
         self.max_text_len = self.model.dims.n_text_ctx
         self.num_decoder_layers = len(self.model.decoder.blocks)
         self.cfg = cfg
+        self.l_hooks = []

         # model to detect end-of-word boundary at the end of the segment
         self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
@@ -69,7 +70,8 @@ class PaddedAlignAttWhisper:
             t = F.softmax(net_output[1], dim=-1)
             self.dec_attns.append(t.squeeze(0))

         for b in self.model.decoder.blocks:
-            b.cross_attn.register_forward_hook(layer_hook)
+            hook = b.cross_attn.register_forward_hook(layer_hook)
+            self.l_hooks.append(hook)

         self.kv_cache = {}
         def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
@@ -82,10 +84,13 @@ class PaddedAlignAttWhisper:
             return self.kv_cache[module.cache_id]

         for i,b in enumerate(self.model.decoder.blocks):
-            b.attn.key.register_forward_hook(kv_hook)
-            b.attn.value.register_forward_hook(kv_hook)
-            b.cross_attn.key.register_forward_hook(kv_hook)
-            b.cross_attn.value.register_forward_hook(kv_hook)
+            hooks = [
+                b.attn.key.register_forward_hook(kv_hook),
+                b.attn.value.register_forward_hook(kv_hook),
+                b.cross_attn.key.register_forward_hook(kv_hook),
+                b.cross_attn.value.register_forward_hook(kv_hook),
+            ]
+            self.l_hooks.extend(hooks)

         self.align_source = {}
         self.num_align_heads = 0
@@ -139,6 +144,11 @@ class PaddedAlignAttWhisper:
         self.inference.kv_cache = self.kv_cache
         self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)

+    def remove_hooks(self):
+        print('remove hook')
+        for hook in self.l_hooks:
+            hook.remove()
+
     def create_tokenizer(self, language=None):
         self.tokenizer = tokenizer.get_tokenizer(