From fac8659161ab902b25d56831ecc0bcf2de83926f Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 21 Nov 2025 23:52:00 +0100 Subject: [PATCH] uses native mlx function for attention --- whisperlivekit/simul_whisper/mlx/decoders.py | 219 +++++ .../simul_whisper/mlx/simul_whisper.py | 752 ++++++++++++++++++ 2 files changed, 971 insertions(+) create mode 100644 whisperlivekit/simul_whisper/mlx/decoders.py create mode 100644 whisperlivekit/simul_whisper/mlx/simul_whisper.py diff --git a/whisperlivekit/simul_whisper/mlx/decoders.py b/whisperlivekit/simul_whisper/mlx/decoders.py new file mode 100644 index 0000000..58c3b8c --- /dev/null +++ b/whisperlivekit/simul_whisper/mlx/decoders.py @@ -0,0 +1,219 @@ +""" +MLX-native token decoders for streaming ASR. +""" +from typing import Any, Dict, List, Optional, Tuple + +import mlx.core as mx +import numpy as np + + +class MLXGreedyDecoder: + """Greedy decoder using MLX operations.""" + + def __init__(self, temperature: float, eot: int): + self.temperature = temperature + self.eot = eot + + def update( + self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array + ) -> Tuple[mx.array, bool]: + """ + Update tokens with next predicted token. + + Args: + tokens: Current token sequence, shape (batch, seq_len) + logits: Logits for next token, shape (batch, vocab_size) + sum_logprobs: Cumulative log probabilities, shape (batch,) + + Returns: + Updated tokens and completion flag + """ + if self.temperature == 0: + next_tokens = mx.argmax(logits, axis=-1) + else: + probs = mx.softmax(logits / self.temperature, axis=-1) + next_tokens = mx.random.categorical(mx.log(probs + 1e-10)) + + logprobs = mx.softmax(logits, axis=-1) + logprobs = mx.log(logprobs + 1e-10) + batch_size = logprobs.shape[0] + current_logprobs = logprobs[mx.arange(batch_size), next_tokens] + mask = (tokens[:, -1] != self.eot).astype(mx.float32) + sum_logprobs = sum_logprobs + current_logprobs * mask + eot_mask = (tokens[:, -1] == self.eot) + next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens) + tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1) + completed = bool(mx.all(tokens[:, -1] == self.eot)) + + return tokens, completed + + def finalize(self, tokens: mx.array, sum_logprobs: mx.array): + """Finalize decoding by ensuring EOT at end.""" + eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype) + tokens = mx.concatenate([tokens, eot_column], axis=1) + return tokens, sum_logprobs.tolist() + + +class MLXBeamSearchDecoder: + """Beam search decoder using MLX operations.""" + + def __init__( + self, + beam_size: int, + eot: int, + inference: Any, + patience: Optional[float] = None, + ): + self.beam_size = beam_size + self.eot = eot + self.inference = inference + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) + self.finished_sequences: Optional[List[Dict]] = None + + assert ( + self.max_candidates > 0 + ), f"Invalid beam size ({beam_size}) or patience ({patience})" + + def reset(self): + """Reset finished sequences for new segment.""" + self.finished_sequences = None + + def update( + self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array + ) -> Tuple[mx.array, bool]: + """ + Update tokens using beam search. 
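+
+        Candidate scoring and beam bookkeeping run in numpy on the host
+        (mirroring whisper.decoding.BeamSearchDecoder), while the KV cache
+        is reordered on-device via MLXInference.rearrange_kv_cache.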
+
+        Args:
+            tokens: Current token sequences, shape (batch * beam_size, seq_len)
+            logits: Logits for next token, shape (batch * beam_size, vocab_size)
+            sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
+
+        Returns:
+            Updated tokens and completion flag
+        """
+        if tokens.shape[0] % self.beam_size != 0:
+            raise ValueError(
+                f"tokens.shape[0] ({tokens.shape[0]}) must be a multiple of beam_size ({self.beam_size})"
+            )
+
+        n_audio = tokens.shape[0] // self.beam_size
+        if self.finished_sequences is None:
+            self.finished_sequences = [{} for _ in range(n_audio)]
+        logprobs = mx.softmax(logits, axis=-1)
+        logprobs = mx.log(logprobs + 1e-10)
+        logprobs_np = np.array(logprobs)
+        tokens_np = np.array(tokens)
+        sum_logprobs_np = np.array(sum_logprobs)
+
+        next_tokens, source_indices, finished_sequences = [], [], []
+        new_sum_logprobs = []
+
+        for i in range(n_audio):
+            scores, sources, finished = {}, {}, {}
+            for j in range(self.beam_size):
+                idx = i * self.beam_size + j
+                prefix = tokens_np[idx].tolist()
+                # beam_size + 1 candidates per beam, so enough continuations
+                # survive even when one candidate is EOT
+                top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
+
+                for token_idx in top_k_indices:
+                    logprob = logprobs_np[idx, token_idx]
+                    new_logprob = sum_logprobs_np[idx] + logprob
+                    sequence = tuple(prefix + [int(token_idx)])
+                    scores[sequence] = new_logprob
+                    sources[sequence] = idx
+            saved = 0
+            for sequence in sorted(scores, key=scores.get, reverse=True):
+                if sequence[-1] == self.eot:
+                    finished[sequence] = scores[sequence]
+                else:
+                    new_sum_logprobs.append(scores[sequence])
+                    next_tokens.append(sequence)
+                    source_indices.append(sources[sequence])
+
+                    saved += 1
+                    if saved == self.beam_size:
+                        break
+
+            finished_sequences.append(finished)
+        tokens = mx.array(np.array(next_tokens, dtype=np.int32))
+        sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
+        self.inference.rearrange_kv_cache(source_indices)
+        assert len(self.finished_sequences) == len(finished_sequences)
+        for previously_finished, newly_finished in zip(
+            self.finished_sequences, finished_sequences
+        ):
+            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
+                if len(previously_finished) >= self.max_candidates:
+                    break
+                previously_finished[seq] = newly_finished[seq]
+        completed = all(
+            len(sequences) >= self.max_candidates
+            for sequences in self.finished_sequences
+        )
+
+        return tokens, completed
+
+    def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
+        """Finalize beam search by selecting best sequences."""
+        preceding_tokens_np = np.array(preceding_tokens)
+        sum_logprobs_np = np.array(sum_logprobs)
+
+        n_audio = preceding_tokens_np.shape[0] // self.beam_size
+        tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
+        sum_logprobs_list: List[float] = [0.0] * n_audio
+
+        for i, sequences in enumerate(self.finished_sequences):
+            if sequences:
+                best_seq = max(sequences, key=sequences.get)
+                tokens_list[i] = list(best_seq)
+                sum_logprobs_list[i] = sequences[best_seq]
+            else:
+                idx = i * self.beam_size
+                tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
+                sum_logprobs_list[i] = float(sum_logprobs_np[idx])
+        max_len = max(len(t) for t in tokens_list)
+        for i, t in enumerate(tokens_list):
+            tokens_list[i] = t + [self.eot] * (max_len - len(t))
+
+        tokens = mx.array(np.array(tokens_list, dtype=np.int32))
+        return tokens, sum_logprobs_list
+
+
+class MLXInference:
+    """MLX inference wrapper for beam search KV cache management."""
+
+    def __init__(self, model, initial_token_length: int):
+        self.model = model
+        self.initial_token_length = initial_token_length
+        self.kv_cache = None
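+        # Expected cache layout (one entry per decoder block):
+        # ((self_k, self_v), (cross_k, cross_v)), with the beam/batch axis
+        # leading so rearrange_kv_cache can reorder beams by indexing axis 0.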
+
+    def rearrange_kv_cache(self, source_indices: List[int]):
+        """Rearrange KV cache based on beam search source indices."""
+        if self.kv_cache is None:
+            return
+
+        if source_indices == list(range(len(source_indices))):
+            return
+
+        source_indices_mx = mx.array(source_indices, dtype=mx.int32)
+
+        new_cache = []
+        for layer_cache in self.kv_cache:
+            (k, v), (cross_k, cross_v) = layer_cache
+            new_k = k[source_indices_mx]
+            new_v = v[source_indices_mx]
+            new_cache.append(((new_k, new_v), (cross_k, cross_v)))
+
+        self.kv_cache = new_cache
+
+    def logits(
+        self,
+        tokens: mx.array,
+        audio_features: mx.array,
+    ) -> Tuple[mx.array, List]:
+        """Get logits from decoder with KV cache."""
+        logits, self.kv_cache, cross_qk = self.model.decoder(
+            tokens, audio_features, kv_cache=self.kv_cache
+        )
+        return logits, cross_qk
+
diff --git a/whisperlivekit/simul_whisper/mlx/simul_whisper.py b/whisperlivekit/simul_whisper/mlx/simul_whisper.py
new file mode 100644
index 0000000..6d075cc
--- /dev/null
+++ b/whisperlivekit/simul_whisper/mlx/simul_whisper.py
@@ -0,0 +1,752 @@
+"""
+MLX-native Whisper AlignAtt streaming decoder.
+"""
+import logging
+from time import time
+from typing import Any, List, Optional, Tuple
+
+import mlx.core as mx
+import numpy as np
+
+from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
+from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
+
+from whisperlivekit.timed_objects import ASRToken
+from whisperlivekit.whisper import DecodingOptions, tokenizer
+from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
+
+from ..config import AlignAttConfig
+from .decoder_state import MLXDecoderState
+from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
+
+# Lowest special-token id of the multilingual vocabulary (<|endoftext|>);
+# token ids >= DEC_PAD are special tokens.
+DEC_PAD = 50257
+logger = logging.getLogger(__name__)
+
+
+class MLXTokenBuffer:
+    """Token buffer for MLX-based decoding."""
+
+    def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
+        self.text = text
+        self.prefix_token_ids = prefix_token_ids or []
+        self.tokenizer = tokenizer
+        self.pending_token_ids = []
+
+    def as_token_ids(self, tokenizer=None):
+        if tokenizer is None:
+            tokenizer = self.tokenizer
+        if tokenizer is None:
+            raise ValueError("Tokenizer is not set.")
+        return self.prefix_token_ids + tokenizer.encode(self.text)
+
+    def as_mlx_array(self) -> mx.array:
+        """Return tokens as MLX array."""
+        tok_ids = self.as_token_ids()
+        return mx.array([tok_ids], dtype=mx.int32)
+
+    def as_mlx_array_beam(self, beam: int) -> mx.array:
+        """Return tokens as MLX array repeated for beam search."""
+        t = self.as_mlx_array()
+        return mx.repeat(t, beam, axis=0)
+
+    def as_text(self):
+        return self.text
+
+    @staticmethod
+    def empty(*a, **kw):
+        return MLXTokenBuffer(*a, **kw)
+
+    @staticmethod
+    def from_text(text, *a, **kw):
+        return MLXTokenBuffer(*a, text=text, **kw)
+
+    def is_empty(self):
+        return self.text is None or self.text == ""
+
+    def trim_words(self, num=1, after=0):
+        """Trim words from the beginning of the context."""
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
+
+        ids = tokenizer.encode(self.text[after:])
+        words, wids = self.tokenizer.split_to_word_tokens(ids)
+        if not words:
+            return 0
+        self.text = self.text[:after] + "".join(words[num:])
+        return sum(len(wi) for wi in wids[:num])
+
+    def append_token_ids(self, token_ids):
+        """Append token IDs to the buffer, handling incomplete UTF-8."""
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
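+
+        # Whisper BPE tokens can split a multi-byte UTF-8 character across
+        # chunk boundaries, in which case decode() yields U+FFFD; buffer the
+        # trailing token(s) until they decode cleanly.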
+        all_tokens = self.pending_token_ids + token_ids
+        decoded = tokenizer.decode(all_tokens)
+        replacement_char = "\ufffd"
+
+        if replacement_char in decoded:
+            if len(all_tokens) > 1:
+                decoded_partial = tokenizer.decode(all_tokens[:-1])
+                if replacement_char not in decoded_partial:
+                    self.text += decoded_partial
+                    self.pending_token_ids = [all_tokens[-1]]
+                else:
+                    self.pending_token_ids = all_tokens
+            else:
+                self.pending_token_ids = all_tokens
+        else:
+            self.text += decoded
+            self.pending_token_ids = []
+
+
+def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
+    """
+    Apply median filter along the last axis.
+
+    Args:
+        x: Input array of shape (..., T)
+        filter_width: Width of the median filter (should be odd)
+
+    Returns:
+        Filtered array of same shape
+    """
+    if filter_width <= 1:
+        return x
+
+    pad_width = filter_width // 2
+    shape = x.shape
+
+    left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
+    right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
+    x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
+
+    result = []
+
+    # Simple windowed sort per position; adequate for the <= 1500-frame
+    # encoder axis this is applied to.
+    for i in range(shape[-1]):
+        window = x_padded[..., i:i + filter_width]
+        sorted_window = mx.sort(window, axis=-1)
+        median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
+        result.append(median_val)
+
+    return mx.concatenate(result, axis=-1)
+
+
+class MLXAlignAtt:
+    """
+    MLX-native Alignment-based Attention decoder for SimulStreaming.
+
+    This class runs entirely on MLX, with no PyTorch dependencies for inference.
+    """
+
+    @property
+    def speaker(self):
+        return self.state.speaker
+
+    @speaker.setter
+    def speaker(self, value):
+        self.state.speaker = value
+
+    @property
+    def global_time_offset(self):
+        return self.state.global_time_offset
+
+    @global_time_offset.setter
+    def global_time_offset(self, value):
+        self.state.global_time_offset = value
+
+    def __init__(
+        self,
+        cfg: AlignAttConfig,
+        mlx_model: Any,
+    ) -> None:
+        """
+        Initialize MLX AlignAtt decoder.
+
+        Args:
+            cfg: AlignAtt configuration
+            mlx_model: MLX Whisper model (full model, not just encoder)
+        """
+        self.model = mlx_model
+        self.cfg = cfg
+
+        logger.info(f"MLX Model dimensions: {self.model.dims}")
+
+        self.decode_options = DecodingOptions(
+            language=cfg.language,
+            without_timestamps=True,
+            task=cfg.task
+        )
+        self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
+
+        self.max_text_len = self.model.dims.n_text_ctx
+        self.num_decoder_layers = len(self.model.decoder.blocks)
+
+        if self.cfg.max_context_tokens is None:
+            self.max_context_tokens = self.max_text_len
+        else:
+            self.max_context_tokens = self.cfg.max_context_tokens
+
+        # Initialize per-session state
+        self.state = MLXDecoderState()
+        self._init_state(cfg)
+
+    def _init_state(self, cfg: AlignAttConfig):
+        """Initialize the per-session decoder state."""
+        self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
+        self.state.tokenizer = self.tokenizer
+        self.state.detected_language = cfg.language if cfg.language != "auto" else None
+        self.state.global_time_offset = 0.0
+        self.state.last_attend_frame = -cfg.rewind_threshold
+        self.state.speaker = -1
+
+        if not cfg.cif_ckpt_path:
+            if cfg.never_fire:
+                self.state.never_fire = True
+                self.state.always_fire = False
+            else:
+                self.state.always_fire = True
+                self.state.never_fire = False
+        else:
+            logger.warning("CIF checkpoint provided but CIF is not implemented for MLX; using always_fire=True")
+            self.state.always_fire = True
+            self.state.never_fire = cfg.never_fire
+
+        self._build_alignment_source()
+
+        suppress_tokens = [
+            self.tokenizer.transcribe,
+            self.tokenizer.translate,
+            self.tokenizer.sot,
+            self.tokenizer.sot_prev,
+            self.tokenizer.sot_lm,
+            self.tokenizer.no_timestamps,
+        ] + list(self.tokenizer.all_language_tokens)
+        if self.tokenizer.no_speech is not None:
+            suppress_tokens.append(self.tokenizer.no_speech)
+        self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
+        logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
+
+        self.init_tokens()
+        self.init_context()
+
+        self.state.decoder_type = cfg.decoder_type
+        if cfg.decoder_type == "greedy":
+            logger.info("Using MLX greedy decoder")
+            self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
+        elif cfg.decoder_type == "beam":
+            logger.info("Using MLX beam decoder")
+            self.state.inference = MLXInference(self.model, self.state.initial_token_length)
+            self.state.token_decoder = MLXBeamSearchDecoder(
+                inference=self.state.inference,
+                eot=self.tokenizer.eot,
+                beam_size=cfg.beam_size
+            )
+
+    def _build_alignment_source(self):
+        """Build alignment source mapping from model's alignment_heads."""
+        self.state.align_source = {}
+        self.state.num_align_heads = 0
+
+        alignment_heads = self.model.alignment_heads
+
+        if alignment_heads is None:
+            logger.warning("No alignment heads found in model")
+            return
+
+        if hasattr(alignment_heads, 'tolist'):
+            heads_list = alignment_heads.tolist()
+        else:
+            heads_list = np.array(alignment_heads).tolist()
+
+        for layer_rank, head_id in heads_list:
+            layer_rank = int(layer_rank)
+            head_id = int(head_id)
+            heads = self.state.align_source.get(layer_rank, [])
+            heads.append((self.state.num_align_heads, head_id))
+            self.state.align_source[layer_rank] = heads
+            self.state.num_align_heads += 1
+
+    def warmup(self, audio: np.ndarray):
+        """Warmup the model with sample audio."""
+        try:
+            self.insert_audio(audio)
+            self.infer(is_last=True)
+            self.refresh_segment(complete=True)
+            logger.info("MLX model warmed up successfully")
+        except Exception as e:
+            logger.exception(f"MLX model warmup failed: {e}")
+
+    def create_tokenizer(self, language=None):
+        """Create tokenizer for the given language."""
+        self.tokenizer = tokenizer.get_tokenizer(
+            multilingual=self.tokenizer_is_multilingual,
+            language=language,
+            num_languages=self.model.num_languages,
+            task=self.decode_options.task
+        )
+        self.state.tokenizer = self.tokenizer
+
+    def init_context(self):
+        """Initialize context buffer."""
+        kw = {
+            'tokenizer': self.tokenizer,
+            'prefix_token_ids': [self.tokenizer.sot_prev]
+        }
+        self.state.context = MLXTokenBuffer.empty(**kw)
+        if self.cfg.static_init_prompt is not None:
+            self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
+        if self.cfg.init_prompt is not None:
+            self.state.context.text += self.cfg.init_prompt
+
+    def init_tokens(self):
+        """Initialize token sequence."""
+        logger.debug(f"init tokens, {len(self.state.segments)}")
+        self.state.initial_tokens = mx.array(
+            [self.tokenizer.sot_sequence_including_notimestamps],
+            dtype=mx.int32
+        )
+        self.state.initial_token_length = self.state.initial_tokens.shape[1]
+        self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
+        logger.debug(f"init tokens after, {len(self.state.segments)}")
+        self.state.tokens = [self.state.initial_tokens]
+
+    def trim_context(self):
+        """Trim context if too long."""
+        logger.info("Trimming context")
+        c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
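+        # c counts context tokens beyond the <|startofprev|> prefix; l (below)
+        # is the full prompt length including already-decoded tokens. Both must
+        # fit the decoder window, keeping ~20 tokens of headroom for new output.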
+        logger.info(f"Context text: {self.state.context.as_text()}")
+        l = sum(t.shape[1] for t in self.state.tokens) + c
+        if self.cfg.static_init_prompt is None:
+            after = 0
+        else:
+            after = len(self.cfg.static_init_prompt)
+        while c > self.max_context_tokens or l > self.max_text_len - 20:
+            t = self.state.context.trim_words(after=after)
+            l -= t
+            c -= t
+            logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
+            if t == 0:
+                break
+        logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
+
+    def refresh_segment(self, complete=False):
+        """Refresh segment state."""
+        logger.debug("Refreshing segment:")
+        self.init_tokens()
+        self.state.last_attend_frame = -self.cfg.rewind_threshold
+        self.state.cumulative_time_offset = 0.0
+        self.init_context()
+        logger.debug(f"Context: {self.state.context}")
+        if not complete and len(self.state.segments) > 2:
+            self.state.segments = self.state.segments[-2:]
+        else:
+            logger.debug("removing all segments.")
+            self.state.segments = []
+        self.state.log_segments += 1
+        self.state.pending_incomplete_tokens = []
+
+    def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
+        """Check whether to fire at a word boundary. The CIF module is not
+        ported to MLX, so this reduces to the always_fire/never_fire flags."""
+        if self.state.always_fire:
+            return True
+        if self.state.never_fire:
+            return False
+        return True
+
+    def _current_tokens(self) -> mx.array:
+        """Get current token sequence for decoding."""
+        toks = self.state.tokens
+
+        if toks[0].shape[0] == 1:
+            toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
+
+        if not self.state.context.is_empty():
+            context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
+            toks = [context_toks] + toks
+
+        # Concatenate all tokens
+        if len(toks) > 1:
+            current_tokens = mx.concatenate(toks, axis=1)
+        else:
+            current_tokens = toks[0]
+
+        logger.debug("debug print current_tokens:")
+        self.debug_print_tokens(current_tokens)
+        return current_tokens
+
+    def debug_print_tokens(self, tokens: mx.array):
+        """Debug print token sequences."""
+        tokens_np = np.array(tokens)
+        for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
+            logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
+
+    def segments_len(self) -> float:
+        """Get total length of audio segments in seconds."""
+        return sum(s.shape[0] for s in self.state.segments) / 16000
+
+    def _apply_minseglen(self) -> bool:
+        """Check if we have enough audio to process."""
+        segments_len = self.segments_len()
+        if segments_len < self.cfg.audio_min_len:
+            logger.debug("waiting for next segment")
+            return False
+        return True
+
+    def insert_audio(self, segment: np.ndarray = None):
+        """Insert audio segment into buffer."""
+        if segment is not None:
+            if hasattr(segment, 'numpy'):
+                segment = segment.numpy()
+            self.state.segments.append(segment)
+
+        removed_len = 0
+        segments_len = self.segments_len()
+
+        while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
+            removed_len = self.state.segments[0].shape[0] / 16000
+            segments_len -= removed_len
+            self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
+            self.state.cumulative_time_offset += removed_len
+            self.state.segments = self.state.segments[1:]
+            logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
+
+            if len(self.state.tokens) > 1:
+                # Convert MLX array to list for context
+                token_list = np.array(self.state.tokens[1][0, :]).tolist()
+                self.state.context.append_token_ids(token_list)
+                self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
+
+        return removed_len
+
+    def _clean_cache(self):
+        """Clean the kv_cache after each inference step."""
+        self.state.clean_cache()
+
+    def _suppress_tokens(self, logits: mx.array) -> mx.array:
+        """Apply token suppression to logits."""
+        if self.state.suppress_tokens:
+            suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
+            logits = logits.at[:, suppress_indices].add(-float('inf'))
+        return logits
+
+    def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
+        """Language detection from encoder features."""
+        n_audio = encoder_features.shape[0]
+        x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
+
+        logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
+        logits = logits[:, 0]
+
+        # Mask everything except the language tokens
+        mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
+        language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
+        mask[language_token_indices] = False
+
+        logits = mx.where(mask, mx.array(-float('inf')), logits)
+
+        language_tokens = mx.argmax(logits, axis=-1)
+        language_token_probs = mx.softmax(logits, axis=-1)
+
+        probs_np = np.array(language_token_probs)
+
+        language_probs = [
+            {
+                c: float(probs_np[i, j])
+                for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
+            }
+            for i in range(n_audio)
+        ]
+
+        self._clean_cache()
+        return language_tokens, language_probs
+
+    def infer(self, is_last: bool = False) -> List[ASRToken]:
+        """
+        Main inference method.
+
+        Args:
+            is_last: Whether this is the final chunk
+
+        Returns:
+            List of timestamped ASR tokens
+        """
+        new_segment = True
+
+        if len(self.state.segments) == 0:
+            logger.debug("No segments, nothing to do")
+            return []
+
+        if not self._apply_minseglen():
+            logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
+            return []
+
+        if len(self.state.segments) > 1:
+            input_segments = np.concatenate(self.state.segments, axis=0)
+        else:
+            input_segments = self.state.segments[0]
+
+        beg_encode = time()
+
+        mlx_mel_padded = mlx_log_mel_spectrogram(
+            audio=input_segments,
+            n_mels=self.model.dims.n_mels,
+            padding=N_SAMPLES
+        )
+        mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
+        encoder_feature = self.model.encoder(mlx_mel[None])
+        # padding added N_FRAMES extra mel frames, so the difference is the real
+        # content length; the encoder downsamples the time axis by 2
+        content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
+
+        mx.eval(encoder_feature)
+
+        end_encode = time()
+        logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
+
+        if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
+            seconds_since_start = self.segments_len() - self.state.first_timestamp
+            if seconds_since_start >= 2.0:
+                language_tokens, language_probs = self.lang_id(encoder_feature)
+                top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
+                logger.info(f"Detected language: {top_lan} with p={p:.4f}")
+                self.create_tokenizer(top_lan)
+                self.state.last_attend_frame = -self.cfg.rewind_threshold
+                self.state.cumulative_time_offset = 0.0
+                self.init_tokens()
+                self.init_context()
+                self.state.detected_language = top_lan
+                logger.info(f"Tokenizer language: {self.tokenizer.language}")
+
+        self.trim_context()
+        current_tokens = self._current_tokens()
+
+        fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
+
+        sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
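+        # AlignAtt policy for the loop below: decode token by token and stop
+        # once the newest token's most-attended encoder frame comes within
+        # frame_threshold frames of the end of the audio, since such tokens
+        # depend on audio that is still likely to change.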
+        completed = False
+
+        attn_of_alignment_heads = None
+        most_attended_frame = None
+
+        token_len_before_decoding = current_tokens.shape[1]
+
+        l_absolute_timestamps = []
+        accumulated_cross_attns = []
+
+        audio_duration_s = self.segments_len()
+        max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
+        tokens_produced_this_chunk = 0
+
+        while not completed and current_tokens.shape[1] < self.max_text_len:
+            tokens_produced_this_chunk += 1
+
+            if tokens_produced_this_chunk > max_tokens_per_chunk:
+                logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
+                current_tokens = current_tokens[:, :token_len_before_decoding]
+                break
+
+            if new_segment:
+                tokens_for_logits = current_tokens
+            else:
+                tokens_for_logits = current_tokens[:, -1:]
+
+            if self.state.decoder_type == "greedy":
+                logits, self.state.kv_cache, cross_qk = self.model.decoder(
+                    tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
+                )
+            else:
+                logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
+
+            mx.eval(logits)
+
+            accumulated_cross_attns.append(cross_qk)
+
+            if new_segment and self.tokenizer.no_speech is not None:
+                probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
+                no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
+                if no_speech_probs[0] > self.cfg.nonspeech_prob:
+                    logger.info("no speech, stop")
+                    break
+
+            logits = logits[:, -1, :]  # Last token logits
+
+            # Suppress blank and EOT at segment start
+            if new_segment:
+                blank_tokens = mx.array(self.tokenizer.encode(" ") + [self.tokenizer.eot], dtype=mx.int32)
+                logits = logits.at[:, blank_tokens].add(-float('inf'))
+                new_segment = False
+
+            logits = self._suppress_tokens(logits)
+
+            current_tokens, completed = self.state.token_decoder.update(
+                current_tokens, logits, sum_logprobs
+            )
+            mx.eval(current_tokens)
+
+            logger.debug(f"Decoding completed: {completed}")
+            self.debug_print_tokens(current_tokens)
+
+            attn_of_alignment_heads = self._process_cross_attention(
+                accumulated_cross_attns, content_mel_len
+            )
+
+            most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
+            most_attended_frames_np = np.array(most_attended_frames)
+
+            absolute_timestamps = [
+                (frame * 0.02 + self.state.cumulative_time_offset)
+                for frame in most_attended_frames_np.tolist()
+            ]
+
+            logger.debug(str(most_attended_frames_np.tolist()) + " most att frames")
+            logger.debug(f"Absolute timestamps: {absolute_timestamps}")
+
+            most_attended_frame = int(most_attended_frames_np[0])
+            l_absolute_timestamps.append(absolute_timestamps[0])
+
+            if completed:
+                current_tokens = current_tokens[:, :-1]
+                break
+            if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
+                current_tokens_np = np.array(current_tokens)
+                if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
+                    logger.debug("omit rewinding from special tokens")
+                    self.state.last_attend_frame = most_attended_frame
+                else:
+                    logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
+                    self.state.last_attend_frame = -self.cfg.rewind_threshold
+                    # state.tokens always holds at least the initial tokens
+                    current_tokens = mx.concatenate(self.state.tokens, axis=1)
+                    break
+            else:
+                self.state.last_attend_frame = most_attended_frame
+            if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
+                logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
+                current_tokens = current_tokens[:, :-1]
+                break
+        tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
+        if self.state.pending_incomplete_tokens:
+            logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
+            tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
+
+        if fire_detected or is_last:
+            new_hypothesis = tokens_to_split
+            split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
+        else:
+            split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
+            if len(split_words) > 1:
+                new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
+            else:
+                new_hypothesis = []
+
+        logger.debug(f"new_hypothesis: {new_hypothesis}")
+        new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
+        new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
+        self.state.tokens.append(new_tokens)
+
+        logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
+
+        self._clean_cache()
+
+        if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
+            self.state.first_timestamp = l_absolute_timestamps[0]
+        timestamped_words = []
+        timestamp_idx = 0
+        replacement_char = "\ufffd"
+        # Fallback if a word has no aligned timestamp of its own
+        current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
+
+        for word, word_tokens in zip(split_words, split_tokens):
+            if replacement_char in word:
+                logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
+                timestamp_idx += len(word_tokens)
+                continue
+
+            try:
+                current_timestamp = l_absolute_timestamps[timestamp_idx]
+            except IndexError:
+                pass
+            timestamp_idx += len(word_tokens)
+
+            timestamp_entry = ASRToken(
+                start=round(current_timestamp, 2),
+                end=round(current_timestamp + 0.1, 2),
+                text=word,
+                speaker=self.state.speaker,
+                detected_language=self.state.detected_language
+            ).with_offset(self.state.global_time_offset)
+            timestamped_words.append(timestamp_entry)
+        self.state.pending_incomplete_tokens = []
+        MAX_PENDING_TOKENS = 10
+        if split_words and replacement_char in split_words[-1]:
+            if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
+                self.state.pending_incomplete_tokens = split_tokens[-1]
+                logger.debug("[UTF-8 Fix] Holding incomplete tokens")
+            else:
+                logger.warning(f"[UTF-8 Fix] Dropping {len(split_tokens[-1])} incomplete tokens")
+
+        return timestamped_words
+
+    def _process_cross_attention(
+        self,
+        cross_attns: List[List[mx.array]],
+        content_mel_len: int
+    ) -> mx.array:
+        """
+        Process cross-attention weights for alignment.
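+
+        Mirrors openai-whisper's find_alignment recipe: softmax over frames,
+        z-normalization across the token axis, a width-7 median filter along
+        frames, then a mean over the alignment heads.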
+
+        Args:
+            cross_attns: Cross-attention tensors from each forward pass;
+                each element is a list with one mx.array per decoder layer
+            content_mel_len: Length of the actual audio content, in encoder frames
+
+        Returns:
+            Processed attention tensor, shape (batch, seq_len, content_mel_len)
+        """
+        attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
+        num_decoder_layers = self.num_decoder_layers
+
+        if cross_attns and isinstance(cross_attns[0], list):
+            flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
+        else:
+            flattened_attns = cross_attns
+
+        for idx, attn_mat in enumerate(flattened_attns):
+            if attn_mat is None:
+                continue
+
+            layer_rank = idx % num_decoder_layers
+            align_heads_in_layer = self.state.align_source.get(layer_rank, [])
+
+            if len(align_heads_in_layer) == 0:
+                continue
+            attn_mat = mx.softmax(attn_mat, axis=-1)
+
+            for align_head_rank, head_id in align_heads_in_layer:
+                if self.cfg.beam_size == 1:
+                    if attn_mat.ndim == 4:
+                        a = attn_mat[0, head_id, :, :]
+                    else:
+                        a = attn_mat[head_id, :, :]
+                    a = a[None, :, :]
+                else:
+                    a = attn_mat[:, head_id, :, :]
+                attn_of_alignment_heads[align_head_rank].append(a)
+        tmp = []
+        for mat in attn_of_alignment_heads:
+            if mat:
+                t = mx.concatenate(mat, axis=1)
+                tmp.append(t)
+
+        if not tmp:
+            return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
+        attn_of_alignment_heads = mx.stack(tmp, axis=1)
+
+        # z-normalize each frame across the token axis
+        std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
+        mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
+        attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
+
+        attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
+
+        attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
+
+        attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
+
+        mx.eval(attn_of_alignment_heads)
+        return attn_of_alignment_heads
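
Usage sketch (illustrative only, not part of the patch). It assumes
AlignAttConfig accepts these fields as keyword arguments, and that the model
returned by mlx_whisper's load_model exposes the dims, alignment_heads,
num_languages, encoder, and decoder attributes used above:

    import numpy as np
    from mlx_whisper.load_models import load_model
    from whisperlivekit.simul_whisper.config import AlignAttConfig
    from whisperlivekit.simul_whisper.mlx.simul_whisper import MLXAlignAtt

    cfg = AlignAttConfig(language="en", task="transcribe", decoder_type="greedy")
    asr = MLXAlignAtt(cfg, load_model("mlx-community/whisper-large-v3-mlx"))

    for chunk in audio_chunks:  # hypothetical iterable of 16 kHz float32 PCM
        asr.insert_audio(np.asarray(chunk, dtype=np.float32))
        for token in asr.infer(is_last=False):
            print(token.start, token.end, token.text)
    for token in asr.infer(is_last=True):  # flush the remaining tokens
        print(token.start, token.end, token.text)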