mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
|
|
|
|
@dataclass
class MLXDecoderState:
    """Mutable per-stream decoder state for the MLX backend.

    The MLX kv cache is a list with one entry per layer; each entry is a
    ``((k, v), (cross_k, cross_v))`` pair of tuples of ``mx.array``.
    """

    # Per-layer cache in the layout described above; set back to None by
    # clean_cache() and full_reset().
    kv_cache: Optional[
        List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]
    ] = None

    # Tokenizer and language bookkeeping (the tokenizer type is opaque here).
    tokenizer: Any = None
    detected_language: Optional[str] = None
    reset_tokenizer_to_auto_next_call: bool = False

    # Token buffers and alignment metadata; `tokens` is emptied by full_reset().
    tokens: List[mx.array] = field(default_factory=list)
    initial_tokens: Optional[mx.array] = None
    initial_token_length: int = 0
    sot_index: int = 0
    align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
    num_align_heads: int = 0
    # Audio segments; emptied by full_reset().
    segments: List[np.ndarray] = field(default_factory=list)

    context: Any = None

    # Emptied on every reset().
    pending_incomplete_tokens: List[int] = field(default_factory=list)

    # Timing bookkeeping: reset() zeroes cumulative_time_offset and rewinds
    # last_attend_frame; full_reset() additionally clears first_timestamp.
    global_time_offset: float = 0.0
    cumulative_time_offset: float = 0.0
    first_timestamp: Optional[float] = None
    last_attend_frame: int = 0

    speaker: int = -1
    # Incremented by one on every reset().
    log_segments: int = 0
    cif_weights: Optional[mx.array] = None
    always_fire: bool = False
    never_fire: bool = False

    suppress_tokens: Optional[Tuple[int, ...]] = None

    # Decoding strategy objects; clean_cache() special-cases the "beam" type.
    token_decoder: Any = None
    decoder_type: str = "greedy"

    inference: Any = None

    def clean_cache(self) -> None:
        """Drop the cached key/value arrays held by this state.

        When ``decoder_type`` is ``"beam"`` and an ``inference`` object is
        attached, that object's ``kv_cache`` is cleared as well and the token
        decoder, if one is set, is reset.
        """
        self.kv_cache = None

        beam_mode = self.decoder_type == "beam"
        if not beam_mode or self.inference is None:
            return

        self.inference.kv_cache = None
        # NOTE(review): the token-decoder reset is treated as part of the beam
        # branch here; confirm against the original file that it is not meant
        # to run for every decoder type.
        if self.token_decoder is not None:
            self.token_decoder.reset()

    def reset(self, rewind_threshold: int = 200) -> None:
        """Reset the per-segment bookkeeping and count the reset.

        Args:
            rewind_threshold: ``last_attend_frame`` is set to the negative of
                this value.
        """
        self.last_attend_frame = -rewind_threshold
        self.pending_incomplete_tokens = []
        self.cumulative_time_offset = 0.0
        self.log_segments += 1

    def full_reset(self, rewind_threshold: int = 200) -> None:
        """Run ``reset()`` and also discard audio segments and tokens.

        The kv cache and the first timestamp are cleared too.

        Args:
            rewind_threshold: Forwarded to ``reset()`` for resetting
                ``last_attend_frame``.
        """
        self.reset(rewind_threshold)
        self.kv_cache = self.first_timestamp = None
        self.tokens = []
        self.segments = []
|
|
|