diff --git a/whisperlivekit/simul_whisper/whisper/model.py b/whisperlivekit/simul_whisper/whisper/model.py
index e537447..7fb887e 100644
--- a/whisperlivekit/simul_whisper/whisper/model.py
+++ b/whisperlivekit/simul_whisper/whisper/model.py
@@ -79,15 +79,18 @@ def disable_sdpa():
 
 
 class MultiHeadAttention(nn.Module):
-    use_sdpa = True
+    use_sdpa = False  # Disable SDPA to ensure qk is always computed for hooks
 
-    def __init__(self, n_state: int, n_head: int):
+    def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
         super().__init__()
         self.n_head = n_head
         self.query = Linear(n_state, n_state)
         self.key = Linear(n_state, n_state, bias=False)
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
+        self.cache_id = cache_id
+        self.key.cache_id = f"{cache_id}_key"
+        self.value.cache_id = f"{cache_id}_value"
 
     def forward(
         self,
@@ -140,14 +143,14 @@ class MultiHeadAttention(nn.Module):
 
 
 class ResidualAttentionBlock(nn.Module):
-    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
         super().__init__()
 
-        self.attn = MultiHeadAttention(n_state, n_head)
+        self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
         self.attn_ln = LayerNorm(n_state)
 
         self.cross_attn = (
-            MultiHeadAttention(n_state, n_head) if cross_attention else None
+            MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
         )
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
 
@@ -181,7 +184,7 @@ class AudioEncoder(nn.Module):
         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
 
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+            [ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
        )
         self.ln_post = LayerNorm(n_state)
 
@@ -215,8 +218,8 @@ class TextDecoder(nn.Module):
 
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
             [
-                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
-                for _ in range(n_layer)
+                ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
+                for i in range(n_layer)
             ]
         )
         self.ln = LayerNorm(n_state)
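
The sketch below is illustrative only, not code from this patch or repository: it shows how the string cache_id tags attached to the key/value Linear layers above could be consumed by forward hooks to maintain a key/value cache keyed by stable IDs (e.g. "dec_layer0_self_attn_key") instead of by module identity. The function name install_kv_cache_hooks, the max_text_ctx parameter, and the overwrite-vs-append rule are assumptions modeled on the upstream openai/whisper caching pattern, which also relies on MultiHeadAttention.forward() reusing an existing kv_cache entry rather than recomputing it.

from typing import Dict, List

import torch
from torch import Tensor, nn


def install_kv_cache_hooks(
    decoder: nn.Module,
    cache: Dict[str, Tensor],
    max_text_ctx: int = 448,  # assumed text-context length (n_text_ctx)
) -> List[torch.utils.hooks.RemovableHandle]:
    """Cache the key/value projections of every Linear tagged with a cache_id."""

    def save_to_cache(module: nn.Module, _inputs, output: Tensor) -> Tensor:
        cid = module.cache_id  # hypothetical tag, e.g. "dec_layer0_self_attn_key"
        if cid not in cache or output.shape[1] > max_text_ctx:
            # First token, or a cross-attention projection over the audio
            # states: store as-is, overwriting any stale entry.
            cache[cid] = output
        else:
            # Self-attention projection of newly decoded tokens: append along
            # the time axis so earlier tokens are not projected again.
            cache[cid] = torch.cat([cache[cid], output], dim=1).detach()
        return cache[cid]

    hooks = []
    for module in decoder.modules():
        # Only key/value projections received a cache_id in the patch above.
        if isinstance(module, nn.Linear) and hasattr(module, "cache_id"):
            hooks.append(module.register_forward_hook(save_to_cache))
    return hooks

Keying the cache by string also explains the encoder/decoder prefixes ("enc_layer{i}", "dec_layer{i}"): entries survive module re-instantiation and can be selectively cleared (for example, dropping only "dec_*" entries between streaming hypotheses) without touching the cross-attention state.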