diff --git a/whisperlivekit/simul_whisper/mlx/__init__.py b/whisperlivekit/simul_whisper/mlx/__init__.py new file mode 100644 index 0000000..95be0b5 --- /dev/null +++ b/whisperlivekit/simul_whisper/mlx/__init__.py @@ -0,0 +1,11 @@ +from .decoder_state import MLXDecoderState +from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference +from .simul_whisper import MLXAlignAtt + +__all__ = [ + "MLXAlignAtt", + "MLXBeamSearchDecoder", + "MLXDecoderState", + "MLXGreedyDecoder", + "MLXInference", +] diff --git a/whisperlivekit/simul_whisper/mlx/decoder_state.py b/whisperlivekit/simul_whisper/mlx/decoder_state.py new file mode 100644 index 0000000..332bc8c --- /dev/null +++ b/whisperlivekit/simul_whisper/mlx/decoder_state.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import mlx.core as mx +import numpy as np + + +@dataclass +class MLXDecoderState: + """ + mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer, + where each element is a tuple of mx.arrays. + """ + + kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None + + tokenizer: Any = None + detected_language: Optional[str] = None + reset_tokenizer_to_auto_next_call: bool = False + + tokens: List[mx.array] = field(default_factory=list) + initial_tokens: Optional[mx.array] = None + initial_token_length: int = 0 + sot_index: int = 0 + align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict) + num_align_heads: int = 0 + segments: List[np.ndarray] = field(default_factory=list) + + context: Any = None + + pending_incomplete_tokens: List[int] = field(default_factory=list) + + global_time_offset: float = 0.0 + cumulative_time_offset: float = 0.0 + first_timestamp: Optional[float] = None + last_attend_frame: int = 0 + + speaker: int = -1 + log_segments: int = 0 + cif_weights: Optional[mx.array] = None + always_fire: bool = False + never_fire: bool = False + + suppress_tokens: Optional[Tuple[int, ...]] = None + + token_decoder: Any = None + decoder_type: str = "greedy" + + inference: Any = None + + def clean_cache(self): + self.kv_cache = None + if self.decoder_type == "beam" and self.inference is not None: + self.inference.kv_cache = None + if self.token_decoder is not None: + self.token_decoder.reset() + + def reset(self, rewind_threshold: int = 200): + self.last_attend_frame = -rewind_threshold + self.cumulative_time_offset = 0.0 + self.pending_incomplete_tokens = [] + self.log_segments += 1 + + def full_reset(self, rewind_threshold: int = 200): + """ + Full reset including audio segments and tokens. + + Args: + rewind_threshold: Value for resetting last_attend_frame + """ + self.reset(rewind_threshold) + self.segments = [] + self.tokens = [] + self.kv_cache = None + self.first_timestamp = None +