in MultiHeadAttention and ResidualAttentionBlock include cache_id for compatibility with simulstreaming code

This commit is contained in:
Quentin Fuxa
2025-08-02 13:16:58 +02:00
parent 687e3dd5e2
commit 4cfed6e98e

View File

@@ -79,15 +79,18 @@ def disable_sdpa():
class MultiHeadAttention(nn.Module):
use_sdpa = True
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
def __init__(self, n_state: int, n_head: int):
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
self.cache_id = cache_id
self.key.cache_id = f"{cache_id}_key"
self.value.cache_id = f"{cache_id}_value"
def forward(
self,
@@ -140,14 +143,14 @@ class MultiHeadAttention(nn.Module):
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
@@ -181,7 +184,7 @@ class AudioEncoder(nn.Module):
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
@@ -215,8 +218,8 @@ class TextDecoder(nn.Module):
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
for i in range(n_layer)
]
)
self.ln = LayerNorm(n_state)