diff --git a/README.md b/README.md index 14314b6..b867199 100644 --- a/README.md +++ b/README.md @@ -185,7 +185,6 @@ async def websocket_endpoint(websocket: WebSocket): | `--init-prompt` | Initial prompt for the model | `None` | | `--static-init-prompt` | Static prompt that doesn't scroll | `None` | | `--max-context-tokens` | Maximum context tokens | `None` | -| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` | diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 33ced0d..88573d1 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -103,7 +103,6 @@ class TranscriptionEngine: "init_prompt": None, "static_init_prompt": None, "max_context_tokens": None, - "preload_model_count": 1, } simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs) diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index d342645..7f67a97 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -296,14 +296,6 @@ def parse_args(): help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.", ) - simulstreaming_group.add_argument( - "--preload-model-count", - type=int, - default=1, - dest="preload_model_count", - help="Optional. 
Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).", - ) - simulstreaming_group.add_argument( "--nllb-backend", type=str, diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 760ecfd..7fdf09f 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -49,20 +49,19 @@ class SimulStreamingOnlineProcessor: self.buffer = [] self.committed: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = [] - self.load_new_backend() + self.load_new_alignatt_instance() - #can be moved if asr.tokenizer: self.model.tokenizer = asr.tokenizer - def load_new_backend(self): - model = self.asr.get_new_model_instance() + def load_new_alignatt_instance(self): + """Initialize AlignAtt decoder using the shared model.""" self.model = AlignAtt( cfg=self.asr.cfg, - loaded_model=model, + loaded_model=self.asr.shared_model, mlx_encoder=self.asr.mlx_encoder, fw_encoder=self.asr.fw_encoder, - ) + ) def start_silence(self): tokens, processed_upto = self.process_iter(is_last=True) @@ -70,7 +69,10 @@ class SimulStreamingOnlineProcessor: def end_silence(self, silence_duration, offset): """ - If silences are > MIN_DURATION_REAL_SILENCE, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame + Handle silence period. + + If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear. + Otherwise, insert a small silence and shift the last_attend_frame. 
""" self.end += silence_duration long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE @@ -83,21 +85,20 @@ class SimulStreamingOnlineProcessor: self.model.refresh_segment(complete=True) self.model.global_time_offset = silence_duration + offset - - def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time): """Append an audio chunk to be processed by SimulStreaming.""" # Convert numpy array to torch tensor audio_tensor = torch.from_numpy(audio).float() - self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend. + self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior self.model.insert_audio(audio_tensor) def new_speaker(self, change_speaker: ChangeSpeaker): - self.process_iter(is_last=True) - self.model.refresh_segment(complete=True) - self.model.speaker = change_speaker.speaker - self.global_time_offset = change_speaker.start + """Handle speaker change event.""" + self.process_iter(is_last=True) + self.model.refresh_segment(complete=True) + self.model.speaker = change_speaker.speaker + self.model.global_time_offset = change_speaker.start def get_buffer(self): concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') @@ -122,8 +123,6 @@ class SimulStreamingOnlineProcessor: self.committed.extend(timestamped_words) self.buffer = [] return timestamped_words, self.end - - except Exception as e: logger.exception(f"SimulStreaming processing error: {e}") return [], self.end @@ -139,12 +138,8 @@ class SimulStreamingOnlineProcessor: logger.exception(f"SimulStreaming warmup failed: {e}") def __del__(self): - # free the model and add a new model to stack. 
- # del self.model gc.collect() torch.cuda.empty_cache() - # self.asr.new_model_to_stack() - self.model.remove_hooks() class SimulStreamingASR(): """SimulStreaming backend with AlignAtt policy.""" @@ -229,10 +224,7 @@ class SimulStreamingASR(): self.tokenizer = self.set_translate_task() else: self.tokenizer = None - - - - + self.mlx_encoder, self.fw_encoder = None, None if self.encoder_backend == "mlx-whisper": print('Simulstreaming will use MLX whisper to increase encoding speed.') @@ -256,8 +248,7 @@ class SimulStreamingASR(): device='auto', compute_type='auto', ) - - self.models = [self.load_model() for i in range(self.preload_model_count)] + self.shared_model = self.load_model() def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper): @@ -306,11 +297,11 @@ class SimulStreamingASR(): download_root=self.model_path, decoder_only=self.fast_encoder, custom_alignment_heads=self.custom_alignment_heads - ) + ) warmup_audio = load_file(self.warmup_file) if warmup_audio is not None: warmup_audio = torch.from_numpy(warmup_audio).float() - if self.fast_encoder: + if self.fast_encoder: temp_model = AlignAtt( cfg=self.cfg, loaded_model=whisper_model, @@ -318,27 +309,9 @@ class SimulStreamingASR(): fw_encoder=self.fw_encoder, ) temp_model.warmup(warmup_audio) - temp_model.remove_hooks() else: - # For standard encoder, use the original transcribe warmup - warmup_audio = load_file(self.warmup_file) whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None) return whisper_model - - def get_new_model_instance(self): - """ - SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers. - Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory. 
- """ - if len(self.models) == 0: - self.models.append(self.load_model()) - new_model = self.models.pop() - return new_model - # self.models[0] - - def new_model_to_stack(self): - self.models.append(self.load_model()) - def set_translate_task(self): """Set up translation task.""" diff --git a/whisperlivekit/simul_whisper/beam.py b/whisperlivekit/simul_whisper/beam.py index fba600d..27cec0b 100644 --- a/whisperlivekit/simul_whisper/beam.py +++ b/whisperlivekit/simul_whisper/beam.py @@ -1,18 +1,32 @@ +from torch import Tensor + from whisperlivekit.whisper.decoding import PyTorchInference -# extention of PyTorchInference for beam search class BeamPyTorchInference(PyTorchInference): + """Extension of PyTorchInference for beam search with cross-attention support.""" - def _kv_modules(self): - key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks] - value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks] - return key_modules + value_modules + def _kv_cache_ids(self): + """Get cache_id strings for self-attention key/value modules.""" + key_ids = [block.attn.key_cache_id for block in self.model.decoder.blocks] + value_ids = [block.attn.value_cache_id for block in self.model.decoder.blocks] + return key_ids + value_ids def rearrange_kv_cache(self, source_indices): if source_indices != list(range(len(source_indices))): - for module_cache_id in self._kv_modules(): - self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach() - from torch import Tensor - def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: - return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) \ No newline at end of file + for cache_id in self._kv_cache_ids(): + if cache_id in self.kv_cache: + self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach() + + def logits( + self, + tokens: Tensor, + audio_features: Tensor, + return_cross_attn: bool = False, + ): + """Get logits, 
optionally returning cross-attention weights.""" + return self.model.decoder( + tokens, audio_features, + kv_cache=self.kv_cache, + return_cross_attn=return_cross_attn, + ) \ No newline at end of file diff --git a/whisperlivekit/simul_whisper/decoder_state.py b/whisperlivekit/simul_whisper/decoder_state.py new file mode 100644 index 0000000..bbab43b --- /dev/null +++ b/whisperlivekit/simul_whisper/decoder_state.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple +import torch + + +@dataclass +class DecoderState: + + kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict) + + tokenizer: Any = None + detected_language: Optional[str] = None + reset_tokenizer_to_auto_next_call: bool = False + + tokens: List[torch.Tensor] = field(default_factory=list) + initial_tokens: Optional[torch.Tensor] = None + initial_token_length: int = 0 + sot_index: int = 0 + + align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict) + num_align_heads: int = 0 + + segments: List[torch.Tensor] = field(default_factory=list) + + context: Any = None + + pending_incomplete_tokens: List[int] = field(default_factory=list) + + global_time_offset: float = 0.0 + cumulative_time_offset: float = 0.0 + first_timestamp: Optional[float] = None + last_attend_frame: int = 0 + + speaker: int = -1 + log_segments: int = 0 + + CIFLinear: Optional[torch.nn.Module] = None + always_fire: bool = False + never_fire: bool = False + + suppress_tokens_fn: Any = None + + token_decoder: Any = None + decoder_type: str = "greedy" + + inference: Any = None + + def clean_cache(self): + """Clean the kv_cache after each inference step.""" + self.kv_cache = {} + if self.decoder_type == "beam" and self.inference is not None: + self.inference.kv_cache = self.kv_cache + if self.token_decoder is not None: + self.token_decoder.reset() + + def reset(self, rewind_threshold: int = 200): + """ + Reset transient state for a new segment. 
+ + Args: + rewind_threshold: Value for resetting last_attend_frame + """ + self.last_attend_frame = -rewind_threshold + self.cumulative_time_offset = 0.0 + self.pending_incomplete_tokens = [] + self.log_segments += 1 + + def full_reset(self, rewind_threshold: int = 200): + """ + Full reset including audio segments and tokens. + + Args: + rewind_threshold: Value for resetting last_attend_frame + """ + self.reset(rewind_threshold) + self.segments = [] + self.tokens = [] + self.kv_cache = {} + self.first_timestamp = None + diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index e207249..55a9ce0 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -1,6 +1,7 @@ import logging import os from time import time +from typing import List, Optional, Tuple import numpy as np import torch @@ -20,6 +21,7 @@ from whisperlivekit.whisper.timing import median_filter from ..timed_objects import PUNCTUATION_MARKS from .beam import BeamPyTorchInference from .config import AlignAttConfig +from .decoder_state import DecoderState from .eow_detection import fire_at_boundary, load_cif from .token_buffer import TokenBuffer @@ -53,6 +55,30 @@ def load_coreml_encoder(): class AlignAtt: + """ + Alignment-based Attention decoder for SimulStreaming. + + This class is now hookless - the model can be shared across multiple + sessions, with each session maintaining its own DecoderState. 
+ """ + + # Property accessors for backward compatibility + @property + def speaker(self): + return self.state.speaker + + @speaker.setter + def speaker(self, value): + self.state.speaker = value + + @property + def global_time_offset(self): + return self.state.global_time_offset + + @global_time_offset.setter + def global_time_offset(self, value): + self.state.global_time_offset = value + def __init__( self, cfg: AlignAttConfig, @@ -60,8 +86,7 @@ class AlignAtt: mlx_encoder=None, fw_encoder=None, ) -> None: - self.log_segments = 0 - + # Shared model reference (can be shared across sessions) self.model = loaded_model self.mlx_encoder = mlx_encoder self.fw_encoder = fw_encoder @@ -75,119 +100,89 @@ class AlignAtt: self.device = 'cuda' if torch.cuda.is_available() else 'cpu' logger.info(f"Model dimensions: {self.model.dims}") - self.speaker = -1 self.decode_options = DecodingOptions( - language = cfg.language, - without_timestamps = True, + language=cfg.language, + without_timestamps=True, task=cfg.task ) self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual - self.create_tokenizer(cfg.language if cfg.language != "auto" else None) - # self.create_tokenizer('en') - self.detected_language = cfg.language if cfg.language != "auto" else None - self.global_time_offset = 0.0 - self.reset_tokenizer_to_auto_next_call = False self.max_text_len = self.model.dims.n_text_ctx self.num_decoder_layers = len(self.model.decoder.blocks) self.cfg = cfg - self.l_hooks = [] - - # model to detect end-of-word boundary at the end of the segment - self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg, - n_audio_state=self.model.dims.n_audio_state, - device=self.model.device) - - # install hooks to access encoder-decoder attention - self.dec_attns = [] - def layer_hook(module, net_input, net_output): - # net_output[1]: B*num_head*token_len*audio_len - t = F.softmax(net_output[1], dim=-1) - self.dec_attns.append(t.squeeze(0)) - for b in self.model.decoder.blocks: - hook = 
b.cross_attn.register_forward_hook(layer_hook) - self.l_hooks.append(hook) - - self.kv_cache = {} - def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor): - if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len: - # save as-is, for the first token or cross attention - self.kv_cache[module.cache_id] = net_output - else: - x = self.kv_cache[module.cache_id] - self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach() - return self.kv_cache[module.cache_id] - - for i,b in enumerate(self.model.decoder.blocks): - hooks = [ - b.attn.key.register_forward_hook(kv_hook), - b.attn.value.register_forward_hook(kv_hook), - b.cross_attn.key.register_forward_hook(kv_hook), - b.cross_attn.value.register_forward_hook(kv_hook), - ] - self.l_hooks.extend(hooks) - - self.align_source = {} - self.num_align_heads = 0 - for layer_rank, head_id in self.model.alignment_heads.indices().T: - layer_rank = layer_rank.item() - heads = self.align_source.get(layer_rank, []) - heads.append((self.num_align_heads, head_id.item())) - self.align_source[layer_rank] = heads - self.num_align_heads += 1 - - - # tokens to be suppressed from decoding, to prevent hallucinations - suppress_tokens = [ - self.tokenizer.transcribe, - self.tokenizer.translate, - self.tokenizer.sot, - self.tokenizer.sot_prev, - self.tokenizer.sot_lm, - # self.tokenizer.eot - self.tokenizer.no_timestamps, # added by DM - ] + list(self.tokenizer.all_language_tokens) # added by DM - if self.tokenizer.no_speech is not None: - suppress_tokens.append(self.tokenizer.no_speech) - suppress_tokens = tuple(sorted(set(suppress_tokens))) - logger.debug(f"Suppress tokens: {suppress_tokens}") - sup_tokens = SuppressTokens(suppress_tokens) - self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None) - # blank tokens are suppresed for new segments near the line 334 - - # it's going to be regenerated after lang id - self.segments = [] - self.init_tokens() - - self.last_attend_frame 
= -self.cfg.rewind_threshold - self.cumulative_time_offset = 0.0 - self.first_timestamp = None if self.cfg.max_context_tokens is None: self.max_context_tokens = self.max_text_len else: self.max_context_tokens = self.cfg.max_context_tokens + + # Initialize per-session state + self.state = DecoderState() + self._init_state(cfg) + + def _init_state(self, cfg: AlignAttConfig): + """Initialize the per-session decoder state.""" + # Create tokenizer + self.create_tokenizer(cfg.language if cfg.language != "auto" else None) + self.state.tokenizer = self.tokenizer + self.state.detected_language = cfg.language if cfg.language != "auto" else None + + # Timing state + self.state.global_time_offset = 0.0 + self.state.last_attend_frame = -cfg.rewind_threshold + self.state.speaker = -1 + + # CIF helpers for end-of-word boundary detection + self.state.CIFLinear, self.state.always_fire, self.state.never_fire = load_cif( + cfg, + n_audio_state=self.model.dims.n_audio_state, + device=self.model.device + ) + + # Build alignment source mapping from model's alignment_heads + self.state.align_source = {} + self.state.num_align_heads = 0 + for layer_rank, head_id in self.model.alignment_heads.indices().T: + layer_rank = layer_rank.item() + heads = self.state.align_source.get(layer_rank, []) + heads.append((self.state.num_align_heads, head_id.item())) + self.state.align_source[layer_rank] = heads + self.state.num_align_heads += 1 + + # Build suppress tokens function + suppress_tokens = [ + self.tokenizer.transcribe, + self.tokenizer.translate, + self.tokenizer.sot, + self.tokenizer.sot_prev, + self.tokenizer.sot_lm, + self.tokenizer.no_timestamps, + ] + list(self.tokenizer.all_language_tokens) + if self.tokenizer.no_speech is not None: + suppress_tokens.append(self.tokenizer.no_speech) + suppress_tokens = tuple(sorted(set(suppress_tokens))) + logger.debug(f"Suppress tokens: {suppress_tokens}") + sup_tokens = SuppressTokens(suppress_tokens) + self.state.suppress_tokens_fn = lambda logits: 
sup_tokens.apply(logits, None) + + # Initialize tokens + self.init_tokens() self.init_context() - # decoder type: greedy or beam + # Set up decoder type + self.state.decoder_type = cfg.decoder_type if cfg.decoder_type == "greedy": logger.info("Using greedy decoder") - self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot) - self.decoder_type = "greedy" - + self.state.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot) elif cfg.decoder_type == "beam": - self.decoder_type = "beam" - self.inference = BeamPyTorchInference(self.model, self.initial_token_length) - self.inference.kv_cache = self.kv_cache - - self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size) - - # Tokens to carry over to next chunk for incomplete UTF-8 characters - self.pending_incomplete_tokens = [] - - def remove_hooks(self): - for hook in self.l_hooks: - hook.remove() + logger.info("Using beam decoder") + self.state.inference = BeamPyTorchInference(self.model, self.state.initial_token_length) + self.state.inference.kv_cache = self.state.kv_cache + self.state.token_decoder = BeamSearchDecoder( + inference=self.state.inference, + eot=self.tokenizer.eot, + beam_size=cfg.beam_size + ) def warmup(self, audio): try: @@ -205,96 +200,100 @@ class AlignAtt: num_languages=self.model.num_languages, task=self.decode_options.task ) + self.state.tokenizer = self.tokenizer def init_context(self): kw = {'tokenizer': self.tokenizer, 'device': self.model.device, 'prefix_token_ids': [self.tokenizer.sot_prev]} - self.context = TokenBuffer.empty(**kw) + self.state.context = TokenBuffer.empty(**kw) if self.cfg.static_init_prompt is not None: - self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw) + self.state.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw) if self.cfg.init_prompt is not None: - self.context.text += self.cfg.init_prompt + self.state.context.text += self.cfg.init_prompt def init_tokens(self): - 
logger.debug(f"init tokens, {len(self.segments)}") + logger.debug(f"init tokens, {len(self.state.segments)}") # init tokens (mandatory prompt) - self.initial_tokens = torch.tensor( + self.state.initial_tokens = torch.tensor( self.tokenizer.sot_sequence_including_notimestamps, dtype=torch.long, device=self.model.device).unsqueeze(0) - self.initial_token_length = self.initial_tokens.shape[1] - self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) -# self.segments = [] - logger.debug(f"init tokens after, {len(self.segments)}") - self.tokens = [self.initial_tokens] + self.state.initial_token_length = self.state.initial_tokens.shape[1] + self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot) + logger.debug(f"init tokens after, {len(self.state.segments)}") + self.state.tokens = [self.state.initial_tokens] def trim_context(self): logger.info("Trimming context") - c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids) -# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}") - logger.info(f"Context text: {self.context.as_text()}") -# logger.debug(f"Context tensor: {self.context.as_tensor()}") - l = sum(t.shape[1] for t in self.tokens) + c -# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") + c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids) + logger.info(f"Context text: {self.state.context.as_text()}") + l = sum(t.shape[1] for t in self.state.tokens) + c if self.cfg.static_init_prompt is None: after = 0 else: after = len(self.cfg.static_init_prompt) -# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") while c > self.max_context_tokens or l > self.max_text_len - 20: - t = self.context.trim_words(after=after) + t = self.state.context.trim_words(after=after) l -= t c -= t logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}") if t == 0: break -# logger.debug(f"len 
{l}, c {c}, max_context_tokens {self.max_context_tokens}") - logger.info(f"Context after trim: {self.context.text} (len: {l})") + logger.info(f"Context after trim: {self.state.context.text} (len: {l})") - def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor: - if self.cfg.decoder_type == "greedy": - logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) + def logits( + self, + tokens: torch.Tensor, + audio_features: torch.Tensor, + return_cross_attn: bool = False + ): + """Get logits from decoder, optionally returning cross-attention weights.""" + if self.state.decoder_type == "greedy": + return self.model.decoder( + tokens, audio_features, + kv_cache=self.state.kv_cache, + return_cross_attn=return_cross_attn + ) else: logger.debug(f"Logits shape: {tokens.shape}") - logit = self.inference.logits(tokens, audio_features) - return logit + return self.state.inference.logits( + tokens, audio_features, + return_cross_attn=return_cross_attn + ) def refresh_segment(self, complete=False): - logger.debug("Refreshing segment:") self.init_tokens() - self.last_attend_frame = -self.cfg.rewind_threshold - # self.detected_language = None - self.cumulative_time_offset = 0.0 + self.state.last_attend_frame = -self.cfg.rewind_threshold + self.state.cumulative_time_offset = 0.0 self.init_context() - logger.debug(f"Context: {self.context}") - if not complete and len(self.segments) > 2: - self.segments = self.segments[-2:] + logger.debug(f"Context: {self.state.context}") + if not complete and len(self.state.segments) > 2: + self.state.segments = self.state.segments[-2:] else: logger.debug("removing all segments.") - self.segments = [] - self.log_segments += 1 - - self.pending_incomplete_tokens = [] + self.state.segments = [] + self.state.log_segments += 1 + self.state.pending_incomplete_tokens = [] def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor): - if self.always_fire: return True - if self.never_fire: return False - 
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear) - + if self.state.always_fire: + return True + if self.state.never_fire: + return False + return fire_at_boundary(chunked_encoder_feature, self.state.CIFLinear) def _current_tokens(self): - - toks = self.tokens + toks = self.state.tokens # very first infer: duplicate start of seq to beam_size if toks[0].shape[0] == 1: - toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0) + toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0) - if not self.context.is_empty(): - context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device) + if not self.state.context.is_empty(): + context_toks = self.state.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device) toks = [context_toks] + toks # make it one tensor @@ -314,7 +313,7 @@ class AlignAtt: ### audio buffer def segments_len(self): - segments_len = sum(s.shape[0] for s in self.segments) / 16000 + segments_len = sum(s.shape[0] for s in self.state.segments) / 16000 return segments_len def _apply_minseglen(self): @@ -327,42 +326,36 @@ class AlignAtt: def insert_audio(self, segment=None): if segment is not None: - self.segments.append(segment) + self.state.segments.append(segment) removed_len = 0 # len of audio is bigger than buffer_len. 
Going to remove the first segment segments_len = self.segments_len() - while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len: - removed_len = self.segments[0].shape[0] / 16000 + while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len: + removed_len = self.state.segments[0].shape[0] / 16000 segments_len -= removed_len - self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len) - self.cumulative_time_offset += removed_len # Track cumulative time removed - self.segments = self.segments[1:] - logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s") - if len(self.tokens) > 1: - self.context.append_token_ids(self.tokens[1][0,:].tolist()) - self.tokens = [self.initial_tokens] + self.tokens[2:] + self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len) + self.state.cumulative_time_offset += removed_len # Track cumulative time removed + self.state.segments = self.state.segments[1:] + logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s") + if len(self.state.tokens) > 1: + self.state.context.append_token_ids(self.state.tokens[1][0, :].tolist()) + self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:] return removed_len def _clean_cache(self): - '''clean the cache that stores the attention matrices and kv_cache. - It must be called every time after generation with the model.''' - # cleaning cache - self.dec_attns = [] - self.kv_cache = {} - if self.decoder_type == "beam": - self.inference.kv_cache = self.kv_cache - self.token_decoder.reset() + """Clean the kv_cache after each inference step.""" + self.state.clean_cache() @torch.no_grad() def lang_id(self, encoder_features): """Language detection from encoder features. - This code is trimmed and copy-pasted from whisper.decoding.detect_language . 
+ This code is trimmed and copy-pasted from whisper.decoding.detect_language. """ - # forward pass using a single token, startoftranscript n_audio = encoder_features.shape[0] x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1] + # Note: don't use kv_cache for language detection logits = self.model.logits(x, encoder_features)[:, 0] # collect detected languages; suppress all non-language tokens @@ -392,19 +385,19 @@ class AlignAtt: @torch.no_grad() def infer(self, is_last=False): new_segment = True - if len(self.segments) == 0: + if len(self.state.segments) == 0: logger.debug("No segments, nothing to do") return [] if not self._apply_minseglen(): logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") - input_segments = torch.cat(self.segments, dim=0) + input_segments = torch.cat(self.state.segments, dim=0) return [] # input_segments is concatenation of audio, it's one array - if len(self.segments) > 1: - input_segments = torch.cat(self.segments, dim=0) + if len(self.state.segments) > 1: + input_segments = torch.cat(self.state.segments, dim=0) else: - input_segments = self.segments[0] + input_segments = self.state.segments[0] beg_encode = time() if self.use_mlcore: @@ -458,18 +451,18 @@ class AlignAtt: end_encode = time() # print('Encoder duration:', end_encode-beg_encode) - if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp: - seconds_since_start = self.segments_len() - self.first_timestamp + if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp: + seconds_since_start = self.segments_len() - self.state.first_timestamp if seconds_since_start >= 2.0: language_tokens, language_probs = self.lang_id(encoder_feature) top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) print(f"Detected language: {top_lan} with p={p:.4f}") self.create_tokenizer(top_lan) - self.last_attend_frame = -self.cfg.rewind_threshold - 
self.cumulative_time_offset = 0.0 + self.state.last_attend_frame = -self.cfg.rewind_threshold + self.state.cumulative_time_offset = 0.0 self.init_tokens() self.init_context() - self.detected_language = top_lan + self.state.detected_language = top_lan logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") self.trim_context() @@ -489,92 +482,80 @@ class AlignAtt: l_absolute_timestamps = [] - while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens + accumulated_cross_attns = [] + + while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens if new_segment: tokens_for_logits = current_tokens else: # only need to use the last token except in the first forward pass - tokens_for_logits = current_tokens[:,-1:] + tokens_for_logits = current_tokens[:, -1:] - logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size + # Get logits and cross-attention weights from decoder + result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True) + logits, cross_attns = result + + # Accumulate cross-attention from this forward pass + accumulated_cross_attns.append(cross_attns) if new_segment and self.tokenizer.no_speech is not None: - probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1) + probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1) no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() if no_speech_probs[0] > self.cfg.nonspeech_prob: logger.info("no speech, stop") break - logits = logits[:, -1, :] # logits for the last token + logits = logits[:, -1, :] # logits for the last token - # supress blank tokens only at the beginning of the segment + # suppress blank tokens only at the beginning of the segment if new_segment: logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf new_segment = False - self.suppress_tokens(logits) - current_tokens, 
completed = self.token_decoder.update(current_tokens, logits, sum_logprobs) + self.state.suppress_tokens_fn(logits) + current_tokens, completed = self.state.token_decoder.update(current_tokens, logits, sum_logprobs) logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ") self.debug_print_tokens(current_tokens) - attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)] - for i, attn_mat in enumerate(self.dec_attns): - layer_rank = int(i % len(self.model.decoder.blocks)) - align_heads_in_layer = self.align_source.get(layer_rank, []) - if len(align_heads_in_layer) == 0: - continue - for align_head_rank, head_id in align_heads_in_layer: - if self.cfg.beam_size == 1: - a = attn_mat[head_id, :, :] - a = a.unsqueeze(0) - else: - a = attn_mat[:, head_id, :, :] - attn_of_alignment_heads[align_head_rank].append(a) - tmp = [] - for mat in attn_of_alignment_heads: - t = torch.cat(mat, dim=1) - tmp.append(t) - attn_of_alignment_heads = torch.stack(tmp, dim=1) - std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False) - attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std - attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing - attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1) - attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len] + # Process accumulated cross-attention weights for alignment + attn_of_alignment_heads = self._process_cross_attention(accumulated_cross_attns, content_mel_len) # for each beam, the most attended frame is: - most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1) + most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1) # Calculate absolute timestamps accounting for cumulative offset - absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()] + absolute_timestamps = [ + (frame * 0.02 + 
self.state.cumulative_time_offset) + for frame in most_attended_frames.tolist() + ] logger.debug(str(most_attended_frames.tolist()) + " most att frames") - logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)") + logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.state.cumulative_time_offset:.2f}s)") most_attended_frame = most_attended_frames[0].item() l_absolute_timestamps.append(absolute_timestamps[0]) logger.debug("current tokens" + str(current_tokens.shape)) if completed: - # # stripping the last token, the eot + # stripping the last token, the eot current_tokens = current_tokens[:, :-1] break # for some rare cases where the attention fails - if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold: - # TODO: check this + if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold: if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD: - logger.debug("ommit rewinding from special tokens") - self.last_attend_frame = most_attended_frame + logger.debug("omit rewinding from special tokens") + self.state.last_attend_frame = most_attended_frame else: logger.debug( f"[rewind detected] current attention pos: {most_attended_frame}, " - f"last attention pos: {self.last_attend_frame}; omit this segment") - self.last_attend_frame = -self.cfg.rewind_threshold - current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0] + f"last attention pos: {self.state.last_attend_frame}; omit this segment") + self.state.last_attend_frame = -self.cfg.rewind_threshold + current_tokens = torch.cat(self.state.tokens, dim=1) if len(self.state.tokens) > 0 else self.state.tokens[0] break else: - self.last_attend_frame = most_attended_frame + self.state.last_attend_frame = most_attended_frame if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold): logger.debug(f"attention 
reaches the end: {most_attended_frame}/{content_mel_len}") @@ -594,12 +575,12 @@ class AlignAtt: tokens_to_split = current_tokens[0, token_len_before_decoding:] # Prepend pending tokens from previous chunk if any - if self.pending_incomplete_tokens: - logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}") - pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device) + if self.state.pending_incomplete_tokens: + logger.debug(f"[UTF-8 Fix] Prepending {len(self.state.pending_incomplete_tokens)} pending tokens: {self.state.pending_incomplete_tokens}") + pending_tensor = torch.tensor(self.state.pending_incomplete_tokens, dtype=torch.long, device=self.device) tokens_to_split = torch.cat([pending_tensor, tokens_to_split]) - if fire_detected or is_last: #or punctuation_stop: + if fire_detected or is_last: new_hypothesis = tokens_to_split.flatten().tolist() split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis) else: @@ -610,20 +591,18 @@ class AlignAtt: else: new_hypothesis = [] - logger.debug(f"new_hypothesis: {new_hypothesis}") new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to( device=self.device, ) - self.tokens.append(new_tokens) + self.state.tokens.append(new_tokens) logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}") self._clean_cache() - if len(l_absolute_timestamps) >=2 and self.first_timestamp is None: - self.first_timestamp = l_absolute_timestamps[0] - + if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None: + self.state.first_timestamp = l_absolute_timestamps[0] timestamped_words = [] timestamp_idx = 0 @@ -642,20 +621,85 @@ class AlignAtt: timestamp_idx += len(word_tokens) timestamp_entry = ASRToken( - start=round(current_timestamp, 2), - end=round(current_timestamp + 0.1, 2), - text= word, - speaker=self.speaker, - 
detected_language=self.detected_language - ).with_offset( - self.global_time_offset + start=round(current_timestamp, 2), + end=round(current_timestamp + 0.1, 2), + text=word, + speaker=self.state.speaker, + detected_language=self.state.detected_language + ).with_offset( + self.state.global_time_offset ) timestamped_words.append(timestamp_entry) # Hold incomplete tokens for next chunk - self.pending_incomplete_tokens = [] + self.state.pending_incomplete_tokens = [] if split_words and replacement_char in split_words[-1]: - self.pending_incomplete_tokens = split_tokens[-1] - logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}") + self.state.pending_incomplete_tokens = split_tokens[-1] + logger.warning(f"[UTF-8 Fix] Holding {len(self.state.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.state.pending_incomplete_tokens}") return timestamped_words + + def _process_cross_attention( + self, + cross_attns: List[torch.Tensor], + content_mel_len: int + ) -> torch.Tensor: + """ + Process cross-attention weights from decoder layers for alignment. + + Args: + cross_attns: List of cross-attention tensors from each decoder layer. 
+ Each tensor has shape (batch, n_head, seq_len, audio_len) + content_mel_len: Length of actual audio content in mel frames + + Returns processed attention tensor for alignment, shape (batch, seq_len, content_mel_len) + """ + attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)] + num_decoder_layers = len(self.model.decoder.blocks) + + if cross_attns and isinstance(cross_attns[0], list): + flattened_attns: List[torch.Tensor] = [attn for layer_list in cross_attns for attn in layer_list] + else: + flattened_attns = cross_attns + + for idx, attn_mat in enumerate(flattened_attns): + layer_rank = idx % num_decoder_layers + # attn_mat shape: (batch, n_head, seq_len, audio_len) or (n_head, seq_len, audio_len) for batch=1 + align_heads_in_layer = self.state.align_source.get(layer_rank, []) + if len(align_heads_in_layer) == 0: + continue + + attn_mat = F.softmax(attn_mat, dim=-1) + + for align_head_rank, head_id in align_heads_in_layer: + if self.cfg.beam_size == 1: + # (n_head, seq_len, audio_len) when squeezed + if attn_mat.dim() == 4: + a = attn_mat[0, head_id, :, :] # (seq_len, audio_len) + else: + a = attn_mat[head_id, :, :] + a = a.unsqueeze(0) # (1, seq_len, audio_len) + else: + # attn_mat: (batch, n_head, seq_len, audio_len) + a = attn_mat[:, head_id, :, :] # (batch, seq_len, audio_len) + attn_of_alignment_heads[align_head_rank].append(a) + + tmp = [] + for mat in attn_of_alignment_heads: + if mat: + t = torch.cat(mat, dim=1) # (batch, total_seq_len, audio_len) + tmp.append(t) + + if not tmp: + return torch.zeros(self.cfg.beam_size, 1, content_mel_len, device=self.device) + + # stack all heads: (batch, num_align_heads, seq_len, audio_len) + attn_of_alignment_heads = torch.stack(tmp, dim=1) + + std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False) + attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8) + + attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) + 
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1) + attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len] + return attn_of_alignment_heads diff --git a/whisperlivekit/whisper/decoding.py b/whisperlivekit/whisper/decoding.py index c494c72..1ef7bf7 100644 --- a/whisperlivekit/whisper/decoding.py +++ b/whisperlivekit/whisper/decoding.py @@ -147,16 +147,13 @@ class PyTorchInference(Inference): self.model: "Whisper" = model self.initial_token_length = initial_token_length self.kv_cache = {} - self.hooks = [] - key_modules = [block.attn.key for block in self.model.decoder.blocks] - value_modules = [block.attn.value for block in self.model.decoder.blocks] - self.kv_modules = key_modules + value_modules + self.kv_cache_ids = [] + for block in self.model.decoder.blocks: + self.kv_cache_ids.append(block.attn.key_cache_id) + self.kv_cache_ids.append(block.attn.value_cache_id) def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: - if not self.kv_cache: - self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() - if tokens.shape[-1] > self.initial_token_length: # only need to use the last token except in the first forward pass tokens = tokens[:, -1:] @@ -164,17 +161,14 @@ class PyTorchInference(Inference): return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) def cleanup_caching(self): - for hook in self.hooks: - hook.remove() - self.kv_cache = {} - self.hooks = [] def rearrange_kv_cache(self, source_indices): if source_indices != list(range(len(source_indices))): - for module in self.kv_modules: - # update the key/value cache to contain the selected sequences - self.kv_cache[module] = self.kv_cache[module][source_indices].detach() + for cache_id in self.kv_cache_ids: + if cache_id in self.kv_cache: + # update the key/value cache to contain the selected sequences + self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach() class SequenceRanker: diff --git a/whisperlivekit/whisper/model.py 
b/whisperlivekit/whisper/model.py index b6482a6..2d0a298 100644 --- a/whisperlivekit/whisper/model.py +++ b/whisperlivekit/whisper/model.py @@ -79,18 +79,23 @@ def disable_sdpa(): class MultiHeadAttention(nn.Module): - use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks + use_sdpa = False # Disable SDPA to ensure qk is always computed when needed - def __init__(self, n_state: int, n_head: int, cache_id: str = ""): + def __init__(self, n_state: int, n_head: int, cache_id: str = "", n_text_ctx: int = 448): super().__init__() self.n_head = n_head + self.n_text_ctx = n_text_ctx self.query = Linear(n_state, n_state) self.key = Linear(n_state, n_state, bias=False) self.value = Linear(n_state, n_state) self.out = Linear(n_state, n_state) self.cache_id = cache_id - self.key.cache_id = f"{cache_id}_key" - self.value.cache_id = f"{cache_id}_value" + # Cache IDs for key and value (used with dict-based kv_cache) + self.key_cache_id = f"{cache_id}_key" + self.value_cache_id = f"{cache_id}_value" + # Keep these for backward compatibility with hook-based caching + self.key.cache_id = self.key_cache_id + self.value.cache_id = self.value_cache_id def forward( self, @@ -101,19 +106,45 @@ class MultiHeadAttention(nn.Module): ): q = self.query(x) - if kv_cache is None or xa is None or self.key not in kv_cache: - # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; - # otherwise, perform key/value projections for self- or cross-attention as usual. - k = self.key(x if xa is None else xa) - v = self.value(x if xa is None else xa) + if xa is None: + # Self-attention + k = self.key(x) + v = self.value(x) + if kv_cache is not None: + k, v = self._update_self_attn_cache(k, v, kv_cache) else: - # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
- k = kv_cache[self.key] - v = kv_cache[self.value] + # Cross-attention: compute once and cache, or reuse from cache + if kv_cache is not None and self.key_cache_id in kv_cache: + k = kv_cache[self.key_cache_id] + v = kv_cache[self.value_cache_id] + else: + k = self.key(xa) + v = self.value(xa) + if kv_cache is not None: + kv_cache[self.key_cache_id] = k + kv_cache[self.value_cache_id] = v wv, qk = self.qkv_attention(q, k, v, mask) return self.out(wv), qk + def _update_self_attn_cache( + self, k: Tensor, v: Tensor, kv_cache: dict + ) -> Tuple[Tensor, Tensor]: + """Update self-attention kv cache by concatenating new k,v with cached values.""" + if self.key_cache_id not in kv_cache or k.shape[1] > self.n_text_ctx: + # First token or context overflow: save as-is + kv_cache[self.key_cache_id] = k.detach() + kv_cache[self.value_cache_id] = v.detach() + else: + # Concatenate with existing cache + cached_k = kv_cache[self.key_cache_id] + cached_v = kv_cache[self.value_cache_id] + k = torch.cat([cached_k, k], dim=1).detach() + v = torch.cat([cached_v, v], dim=1).detach() + kv_cache[self.key_cache_id] = k + kv_cache[self.value_cache_id] = v + return k, v + def qkv_attention( self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -143,14 +174,21 @@ class MultiHeadAttention(nn.Module): class ResidualAttentionBlock(nn.Module): - def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""): + def __init__( + self, n_state: int, n_head: int, cross_attention: bool = False, + cache_id: str = "", n_text_ctx: int = 448 + ): super().__init__() - self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn") + self.attn = MultiHeadAttention( + n_state, n_head, cache_id=f"{cache_id}_self_attn", n_text_ctx=n_text_ctx + ) self.attn_ln = LayerNorm(n_state) self.cross_attn = ( - MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention 
else None + MultiHeadAttention( + n_state, n_head, cache_id=f"{cache_id}_cross_attn", n_text_ctx=n_text_ctx + ) if cross_attention else None ) self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None @@ -166,12 +204,21 @@ class ResidualAttentionBlock(nn.Module): xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None, - ): + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Returns: + x: The output tensor + cross_attn_qk: Cross-attention weights (if cross_attn exists), else None + """ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + cross_attn_qk = None if self.cross_attn: - x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + cross_out, cross_attn_qk = self.cross_attn( + self.cross_attn_ln(x), xa, kv_cache=kv_cache + ) + x = x + cross_out x = x + self.mlp(self.mlp_ln(x)) - return x + return x, cross_attn_qk class AudioEncoder(nn.Module): @@ -201,7 +248,7 @@ class AudioEncoder(nn.Module): x = (x + self.positional_embedding).to(x.dtype) for block in self.blocks: - x = block(x) + x, _ = block(x) # Encoder blocks don't have cross-attention x = self.ln_post(x) return x @@ -212,13 +259,17 @@ class TextDecoder(nn.Module): self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int ): super().__init__() + self.n_ctx = n_ctx self.token_embedding = nn.Embedding(n_vocab, n_state) self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( [ - ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}") + ResidualAttentionBlock( + n_state, n_head, cross_attention=True, + cache_id=f"dec_layer{i}", n_text_ctx=n_ctx + ) for i in range(n_layer) ] ) @@ -227,28 +278,57 @@ class TextDecoder(nn.Module): mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) self.register_buffer("mask", mask, persistent=False) - def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + 
def forward( + self, + x: Tensor, + xa: Tensor, + kv_cache: Optional[dict] = None, + return_cross_attn: bool = False, + ): """ x : torch.LongTensor, shape = (batch_size, <= n_ctx) the text tokens xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) the encoded audio features to be attended on + kv_cache : Optional[dict] + Dictionary to store/retrieve key-value cache for efficient decoding + return_cross_attn : bool + If True, return cross-attention weights from all decoder layers + + Returns + ------- + logits : Tensor + The output logits + cross_attns : Optional[List[Tensor]] + List of cross-attention weights per layer (only if return_cross_attn=True) """ - offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + # Calculate offset from self-attention cache (not cross-attention which has audio length) + offset = 0 + if kv_cache: + # Use the first decoder block's self-attention key cache to get token position + first_self_attn_key = self.blocks[0].attn.key_cache_id + if first_self_attn_key in kv_cache: + offset = kv_cache[first_self_attn_key].shape[1] + x = ( self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] ) x = x.to(xa.dtype) + cross_attns = [] if return_cross_attn else None for block in self.blocks: - x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + x, cross_attn_qk = block(x, xa, mask=self.mask, kv_cache=kv_cache) + if return_cross_attn and cross_attn_qk is not None: + cross_attns.append(cross_attn_qk) x = self.ln(x) logits = ( x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) ).float() + if return_cross_attn: + return logits, cross_attns return logits @@ -292,8 +372,18 @@ class Whisper(nn.Module): def embed_audio(self, mel: torch.Tensor): return self.encoder(mel) - def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): - return self.decoder(tokens, audio_features) + def logits( + self, + tokens: torch.Tensor, + audio_features: torch.Tensor, + kv_cache: 
Optional[dict] = None, + return_cross_attn: bool = False, + ): + return self.decoder( + tokens, audio_features, + kv_cache=kv_cache, + return_cross_attn=return_cross_attn + ) def forward( self, mel: torch.Tensor, tokens: torch.Tensor @@ -312,39 +402,6 @@ class Whisper(nn.Module): def num_languages(self): return self.dims.n_vocab - 51765 - int(self.is_multilingual) - def install_kv_cache_hooks(self, cache: Optional[dict] = None): - """ - The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value - tensors calculated for the previous positions. This method returns a dictionary that stores - all caches, and the necessary hooks for the key and value projection modules that save the - intermediate tensors to be reused during later calculations. - - Returns - ------- - cache : Dict[nn.Module, torch.Tensor] - A dictionary object mapping the key/value projection modules to its cache - hooks : List[RemovableHandle] - List of PyTorch RemovableHandle objects to stop the hooks to be called - """ - cache = {**cache} if cache is not None else {} - hooks = [] - - def save_to_cache(module, _, output): - if module not in cache or output.shape[1] > self.dims.n_text_ctx: - # save as-is, for the first token or cross attention - cache[module] = output - else: - cache[module] = torch.cat([cache[module], output], dim=1).detach() - return cache[module] - - def install_hooks(layer: nn.Module): - if isinstance(layer, MultiHeadAttention): - hooks.append(layer.key.register_forward_hook(save_to_cache)) - hooks.append(layer.value.register_forward_hook(save_to_cache)) - - self.decoder.apply(install_hooks) - return cache, hooks - detect_language = detect_language_function transcribe = transcribe_function decode = decode_function