diff --git a/whisperlivekit/voxtral_mlx/__init__.py b/whisperlivekit/voxtral_mlx/__init__.py new file mode 100644 index 0000000..d008c70 --- /dev/null +++ b/whisperlivekit/voxtral_mlx/__init__.py @@ -0,0 +1,6 @@ +"""Pure-MLX Voxtral Realtime backend for WhisperLiveKit.""" + +from .loader import load_voxtral_model +from .model import VoxtralMLXModel + +__all__ = ["load_voxtral_model", "VoxtralMLXModel"] diff --git a/whisperlivekit/voxtral_mlx/loader.py b/whisperlivekit/voxtral_mlx/loader.py new file mode 100644 index 0000000..486bd71 --- /dev/null +++ b/whisperlivekit/voxtral_mlx/loader.py @@ -0,0 +1,282 @@ +""" +Model weight loading for the MLX Voxtral Realtime backend. + +Supports two on-disk formats: + 1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load, + with optional quantisation metadata. + 2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``): + requires weight renaming and conv-weight transposition. + +The public entry point is :func:`load_voxtral_model` which returns the +model, tokenizer, and raw config dict. +""" + +import json +import logging +import re +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +from huggingface_hub import snapshot_download + +from .model import VoxtralMLXModel + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit" + +# --------------------------------------------------------------------------- +# Downloading +# --------------------------------------------------------------------------- + +_ALLOWED_PATTERNS = [ + "consolidated.safetensors", + "model*.safetensors", + "model.safetensors.index.json", + "params.json", + "config.json", + "tekken.json", +] + + +def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path: + """Download model files from HuggingFace Hub and return the local path.""" + return Path(snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS)) + + +# --------------------------------------------------------------------------- +# Weight name remapping (Mistral → our naming) +# --------------------------------------------------------------------------- + +_NAME_RULES: list[tuple[str, str]] = [ + # Encoder convolutions + (r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"), + (r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"), + # Encoder transformer blocks + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)", + r"encoder.blocks.\1.self_attn.q_proj.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)", + r"encoder.blocks.\1.self_attn.k_proj.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)", + r"encoder.blocks.\1.self_attn.v_proj.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)", + r"encoder.blocks.\1.self_attn.out_proj.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)", + r"encoder.blocks.\1.pre_attn_norm.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)", + r"encoder.blocks.\1.ffn.gate.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)", + r"encoder.blocks.\1.ffn.down.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)", + r"encoder.blocks.\1.ffn.up.\2"), + (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)", + r"encoder.blocks.\1.pre_ffn_norm.\2"), + (r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"), + # Adapter + 
(r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"), + (r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"), + # Decoder embedding + (r"tok_embeddings\.weight", r"decoder.token_embedding.weight"), + # Decoder blocks + (r"layers\.(\d+)\.attention\.wq\.weight", + r"decoder.blocks.\1.self_attn.q_proj.weight"), + (r"layers\.(\d+)\.attention\.wk\.weight", + r"decoder.blocks.\1.self_attn.k_proj.weight"), + (r"layers\.(\d+)\.attention\.wv\.weight", + r"decoder.blocks.\1.self_attn.v_proj.weight"), + (r"layers\.(\d+)\.attention\.wo\.weight", + r"decoder.blocks.\1.self_attn.out_proj.weight"), + (r"layers\.(\d+)\.attention_norm\.weight", + r"decoder.blocks.\1.pre_attn_norm.weight"), + (r"layers\.(\d+)\.feed_forward\.w1\.weight", + r"decoder.blocks.\1.ffn.gate.weight"), + (r"layers\.(\d+)\.feed_forward\.w2\.weight", + r"decoder.blocks.\1.ffn.down.weight"), + (r"layers\.(\d+)\.feed_forward\.w3\.weight", + r"decoder.blocks.\1.ffn.up.weight"), + (r"layers\.(\d+)\.ffn_norm\.weight", + r"decoder.blocks.\1.pre_ffn_norm.weight"), + (r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight", + r"decoder.blocks.\1.adaptive_scale.proj_in.weight"), + (r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight", + r"decoder.blocks.\1.adaptive_scale.proj_out.weight"), + # Decoder final norm + (r"norm\.weight", r"decoder.final_norm.weight"), +] + +_PREFIX_STRIP = re.compile( + r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\." +) + + +def _translate_weight_name(name: str) -> str | None: + name = _PREFIX_STRIP.sub("", name) + for pattern, replacement in _NAME_RULES: + result, n = re.subn(f"^{pattern}$", replacement, name) + if n: + return result + return None + + +def _is_conv_weight(name: str) -> bool: + return ("conv1.weight" in name or "conv2.weight" in name) and "bias" not in name + + +# --------------------------------------------------------------------------- +# Converted-format weight remapping (voxmlx names → our names) +# --------------------------------------------------------------------------- + +_CONVERTED_RULES: list[tuple[str, str]] = [ + # Adapter + (r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"), + (r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"), + # Encoder transformer blocks + (r"encoder\.layers\.(\d+)\.attention\.(.*)", r"encoder.blocks.\1.self_attn.\2"), + (r"encoder\.layers\.(\d+)\.attn_norm\.(.*)", r"encoder.blocks.\1.pre_attn_norm.\2"), + (r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"encoder.blocks.\1.ffn.gate.\2"), + (r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"encoder.blocks.\1.ffn.down.\2"), + (r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"encoder.blocks.\1.ffn.up.\2"), + (r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)", r"encoder.blocks.\1.pre_ffn_norm.\2"), + (r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"), + # Decoder embedding + (r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"), + # Decoder blocks + (r"language_model\.layers\.(\d+)\.attention\.(.*)", r"decoder.blocks.\1.self_attn.\2"), + (r"language_model\.layers\.(\d+)\.attn_norm\.(.*)", r"decoder.blocks.\1.pre_attn_norm.\2"), + (r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"decoder.blocks.\1.ffn.gate.\2"), + (r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"decoder.blocks.\1.ffn.down.\2"), + (r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"decoder.blocks.\1.ffn.up.\2"), + (r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)", r"decoder.blocks.\1.pre_ffn_norm.\2"), + (r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)", + 
r"decoder.blocks.\1.adaptive_scale.proj_in.\2"), + (r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)", + r"decoder.blocks.\1.adaptive_scale.proj_out.\2"), + (r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"), +] + +# Also remap o_proj → out_proj in both encoder and decoder +_POST_RENAME = [ + (r"\.o_proj\.", r".out_proj."), +] + + +def _remap_converted_name(name: str) -> str: + """Translate a converted-format weight name to our naming convention.""" + for pattern, replacement in _CONVERTED_RULES: + result, n = re.subn(f"^{pattern}$", replacement, name) + if n: + name = result + break + for pattern, replacement in _POST_RENAME: + name = re.sub(pattern, replacement, name) + return name + + +# --------------------------------------------------------------------------- +# Loading strategies +# --------------------------------------------------------------------------- + +def _has_converted_layout(path: Path) -> bool: + return (path / "config.json").exists() and not (path / "consolidated.safetensors").exists() + + +def _load_converted_weights(path: Path): + with open(path / "config.json") as f: + config = json.load(f) + + model = VoxtralMLXModel(config) + + quant = config.get("quantization") + if quant is not None: + gs = quant["group_size"] + nn.quantize( + model, + group_size=gs, + bits=quant["bits"], + class_predicate=lambda _p, m: ( + hasattr(m, "to_quantized") and m.weight.shape[-1] % gs == 0 + ), + ) + + index_file = path / "model.safetensors.index.json" + if index_file.exists(): + with open(index_file) as f: + shard_map = json.load(f) + shard_files = sorted(set(shard_map["weight_map"].values())) + weights = {} + for sf in shard_files: + weights.update(mx.load(str(path / sf))) + else: + weights = mx.load(str(path / "model.safetensors")) + + remapped = {_remap_converted_name(k): v for k, v in weights.items()} + model.load_weights(list(remapped.items())) + mx.eval(model.parameters()) + return model, config + + +def _load_original_weights(path: Path): + with open(path / "params.json") as f: + config = json.load(f) + + model = VoxtralMLXModel(config) + + raw = mx.load(str(path / "consolidated.safetensors")) + mapped: dict[str, mx.array] = {} + skipped: list[str] = [] + + for name, tensor in raw.items(): + if name == "output.weight": + continue + new_name = _translate_weight_name(name) + if new_name is None: + skipped.append(name) + continue + # Conv weights: PyTorch [C_out, C_in, K] → MLX [C_out, K, C_in] + if _is_conv_weight(new_name): + tensor = mx.swapaxes(tensor, 1, 2) + mapped[new_name] = tensor + + if skipped: + logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5]) + + model.load_weights(list(mapped.items())) + mx.eval(model.parameters()) + return model, config + + +# --------------------------------------------------------------------------- +# Tokenizer +# --------------------------------------------------------------------------- + +def _load_tokenizer(model_dir: Path): + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + return Tekkenizer.from_file(str(model_dir / "tekken.json")) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID): + """Load a Voxtral Realtime model and its tokenizer. + + Args: + path_or_id: Local directory path **or** a HuggingFace model ID. 
+ + Returns: + ``(model, tokenizer, config)`` + """ + p = Path(path_or_id) + if not p.exists(): + p = download_weights(path_or_id) + + if _has_converted_layout(p): + model, config = _load_converted_weights(p) + else: + model, config = _load_original_weights(p) + + tokenizer = _load_tokenizer(p) + logger.info("Voxtral MLX model loaded from %s", p) + return model, tokenizer, config diff --git a/whisperlivekit/voxtral_mlx/model.py b/whisperlivekit/voxtral_mlx/model.py new file mode 100644 index 0000000..0a637f8 --- /dev/null +++ b/whisperlivekit/voxtral_mlx/model.py @@ -0,0 +1,534 @@ +""" +Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model. + +Architecture: + audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits + with DelayEmbedding providing time-conditioning to the decoder. + +The model supports both batch inference (full audio) and incremental streaming +(one chunk at a time with cached encoder/decoder state). +""" + +import math + +import mlx.core as mx +import mlx.nn as nn + + +# --------------------------------------------------------------------------- +# KV Cache +# --------------------------------------------------------------------------- + + +class SlidingKVCache: + """Bounded key-value cache with rotating buffer for sliding-window attention. + + Uses in-place writes for single-token autoregressive steps and + concatenation for multi-token prefills. Pre-allocates in blocks of + ``alloc_step`` entries to reduce repeated allocation. + """ + + alloc_step = 256 + + def __init__(self, capacity: int): + self.capacity = capacity + self.keys = None + self.values = None + self._offset = 0 + self._write_idx = 0 + + @property + def offset(self) -> int: + return self._offset + + # -- helpers -- + + def _reorder(self, buf): + """Return *buf* in temporal order (unwrap the circular buffer).""" + if self._write_idx == buf.shape[2]: + return buf + if self._write_idx < self._offset: + return mx.concatenate( + [buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]], + axis=2, + ) + return buf[..., : self._write_idx, :] + + def _drop_oldest(self, buf, n_drop, tail=None): + parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf] + if tail is not None: + parts.append(tail) + return mx.concatenate(parts, axis=2) + + # -- update strategies -- + + def _append_concat(self, k, v): + """Multi-token update via concatenation (used during prefill).""" + if self.keys is None: + self.keys, self.values = k, v + else: + self.keys = self._reorder(self.keys) + self.values = self._reorder(self.values) + self._write_idx = self.keys.shape[2] + overflow = self._write_idx - self.capacity + 1 + self.keys = self._drop_oldest(self.keys, overflow, k) + self.values = self._drop_oldest(self.values, overflow, v) + self._offset += k.shape[2] + self._write_idx = self.keys.shape[2] + return self.keys, self.values + + def _write_inplace(self, k, v): + """Single-token update via in-place write (autoregressive step).""" + B, n_heads, S, dim_k = k.shape + dim_v = v.shape[3] + prev = self._offset + + if self.keys is None or ( + prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity + ): + n_new = min(self.alloc_step, self.capacity - prev) + fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype) + fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, fresh_k], axis=2) + self.values = mx.concatenate([self.values, fresh_v], axis=2) + else: + self.keys, self.values = fresh_k, fresh_v + self._write_idx = 
prev + + overflow = self.keys.shape[2] - self.capacity + if overflow > 0: + self.keys = self._drop_oldest(self.keys, overflow) + self.values = self._drop_oldest(self.values, overflow) + self._write_idx = self.capacity + + if self._write_idx == self.capacity: + self._write_idx = 0 + + self.keys[..., self._write_idx : self._write_idx + S, :] = k + self.values[..., self._write_idx : self._write_idx + S, :] = v + self._offset += S + self._write_idx += S + + if self._offset < self.capacity: + return ( + self.keys[..., : self._offset, :], + self.values[..., : self._offset, :], + ) + return self.keys, self.values + + # -- public API -- + + def update_and_fetch(self, k, v): + if k.shape[2] == 1: + return self._write_inplace(k, v) + return self._append_concat(k, v) + + +# --------------------------------------------------------------------------- +# Encoder components +# --------------------------------------------------------------------------- + + +class CausalConv(nn.Module): + """1-D causal convolution (left-padded so no future leakage).""" + + def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1): + super().__init__() + self.stride = stride + self.kernel = kernel + self.left_pad = kernel - stride + self.weight = mx.zeros((channels_out, kernel, channels_in)) + self.bias = mx.zeros((channels_out,)) + + def __call__(self, x: mx.array) -> mx.array: + if self.left_pad > 0: + x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)]) + return mx.conv1d(x, self.weight, stride=self.stride) + self.bias + + +class _EncoderSelfAttention(nn.Module): + def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float): + super().__init__() + self.n_heads = n_heads + self.head_dim = head_dim + self.scale = head_dim**-0.5 + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True) + self.rope_theta = rope_theta + + def __call__(self, x, mask, cache=None): + B, L, _ = x.shape + q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3) + + pos = cache.offset if cache is not None else 0 + q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos) + k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos) + + if cache is not None: + k, v = cache.update_and_fetch(k, v) + + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask) + return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1)) + + +class _EncoderFFN(nn.Module): + """SwiGLU feed-forward for encoder layers.""" + + def __init__(self, dim: int, hidden: int): + super().__init__() + self.gate = nn.Linear(dim, hidden, bias=False) + self.up = nn.Linear(dim, hidden, bias=False) + self.down = nn.Linear(hidden, dim, bias=True) + + def __call__(self, x): + return self.down(nn.silu(self.gate(x)) * self.up(x)) + + +class _EncoderBlock(nn.Module): + def __init__(self, dim, n_heads, head_dim, hidden, rope_theta): + super().__init__() + self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5) + self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta) + self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5) + self.ffn = _EncoderFFN(dim, hidden) 
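+        # Note: SwiGLU FFN (gate/up/down) rather than Whisper's GELU MLP; the
+        # loader maps the original w1/w2/w3 weights onto gate/down/up.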
+ + def __call__(self, x, mask, cache=None): + x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache) + x = x + self.ffn(self.pre_ffn_norm(x)) + return x + + +class StreamingEncoder(nn.Module): + """Causal Whisper-style encoder with two causal convolutions followed by + a stack of transformer blocks. Supports both full-sequence and + incremental (streaming) forward passes.""" + + def __init__( + self, + mel_channels: int = 128, + dim: int = 1280, + n_layers: int = 32, + n_heads: int = 32, + head_dim: int = 64, + hidden_dim: int = 5120, + rope_theta: float = 1e6, + sliding_window: int = 750, + ): + super().__init__() + self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1) + self.conv2 = CausalConv(dim, dim, kernel=3, stride=2) + self.blocks = [ + _EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta) + for _ in range(n_layers) + ] + self.final_norm = nn.RMSNorm(dim, eps=1e-5) + self.sliding_window = sliding_window + + # -- full-sequence -- + + def _apply_convs(self, mel: mx.array) -> mx.array: + x = mel.T[None, :, :] # [1, T, mel_channels] + x = nn.gelu(self.conv1(x)) + x = nn.gelu(self.conv2(x)) + return x + + def forward(self, mel: mx.array) -> mx.array: + x = self._apply_convs(mel.astype(self.conv1.weight.dtype)) + for blk in self.blocks: + x = blk(x, mask="causal") + return self.final_norm(x) + + # -- incremental (streaming) -- + + def forward_conv_incremental(self, x_in, tail1, tail2): + """Process new mel frames through the two causal convs using cached tails. + + Args: + x_in: [1, N, mel_channels] + tail1: [1, pad1, mel_channels] or None (first call) + tail2: [1, pad2, dim] or None (first call) + + Returns: + (out, new_tail1, new_tail2) + """ + # Conv1 (kernel=3, stride=1 → left_pad=2) + if tail1 is not None: + c1_in = mx.concatenate([tail1, x_in], axis=1) + else: + c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)]) + new_tail1 = x_in[:, -self.conv1.left_pad :, :] + c1_out = nn.gelu( + mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias + ) + + # Conv2 (kernel=3, stride=2 → left_pad=1) + if tail2 is not None: + c2_in = mx.concatenate([tail2, c1_out], axis=1) + else: + c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)]) + new_tail2 = c1_out[:, -self.conv2.left_pad :, :] + c2_out = nn.gelu( + mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias + ) + + return c2_out, new_tail1, new_tail2 + + def forward_transformer_incremental(self, x, cache_list): + """Run transformer blocks with per-layer KV caches.""" + for i, blk in enumerate(self.blocks): + x = blk(x, mask="causal", cache=cache_list[i]) + return self.final_norm(x) + + +# --------------------------------------------------------------------------- +# Decoder components +# --------------------------------------------------------------------------- + + +class _DecoderAttention(nn.Module): + """Grouped-query attention for the text decoder.""" + + def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.scale = head_dim**-0.5 + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + self.rope_theta = rope_theta + + def __call__(self, x, mask=None, cache=None): + B, L, _ = x.shape + q = 
self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + pos = cache.offset if cache is not None else 0 + q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos) + k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos) + + if cache is not None: + k, v = cache.update_and_fetch(k, v) + + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask) + return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1)) + + +class _DecoderFFN(nn.Module): + """SwiGLU feed-forward for decoder layers.""" + + def __init__(self, dim, hidden): + super().__init__() + self.gate = nn.Linear(dim, hidden, bias=False) + self.up = nn.Linear(dim, hidden, bias=False) + self.down = nn.Linear(hidden, dim, bias=False) + + def __call__(self, x): + return self.down(nn.silu(self.gate(x)) * self.up(x)) + + +class AdaptiveScaling(nn.Module): + """Small MLP that produces a multiplicative scale from the delay embedding, + used to condition the FFN on the streaming delay.""" + + def __init__(self, dim, bottleneck): + super().__init__() + self.proj_in = nn.Linear(dim, bottleneck, bias=False) + self.proj_out = nn.Linear(bottleneck, dim, bias=False) + + def __call__(self, cond): + return self.proj_out(nn.gelu(self.proj_in(cond))) + + +class _DecoderBlock(nn.Module): + def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim): + super().__init__() + self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5) + self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta) + self.adaptive_scale = AdaptiveScaling(dim, cond_dim) + self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5) + self.ffn = _DecoderFFN(dim, hidden) + + def __call__(self, x, delay_cond, mask=None, cache=None): + x = x + self.self_attn(self.pre_attn_norm(x), mask, cache) + scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond)) + x = x + self.ffn(scaled) + return x + + +class TextDecoder(nn.Module): + """Mistral-style causal language model with adaptive time-conditioning.""" + + def __init__( + self, + dim: int = 3072, + n_layers: int = 26, + n_heads: int = 32, + n_kv_heads: int = 8, + head_dim: int = 128, + hidden_dim: int = 9216, + vocab_size: int = 131072, + rope_theta: float = 1e6, + cond_dim: int = 32, + ): + super().__init__() + self.token_embedding = nn.Embedding(vocab_size, dim) + self.blocks = [ + _DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim) + for _ in range(n_layers) + ] + self.final_norm = nn.RMSNorm(dim, eps=1e-5) + + def embed(self, token_ids: mx.array) -> mx.array: + return self.token_embedding(token_ids) + + def __call__(self, x, delay_cond, mask=None, cache=None): + delay_cond = delay_cond.astype(x.dtype) + for i, blk in enumerate(self.blocks): + blk_cache = cache[i] if cache is not None else None + x = blk(x, delay_cond, mask, blk_cache) + x = self.final_norm(x) + return self.token_embedding.as_linear(x) + + +# --------------------------------------------------------------------------- +# Adapter & embeddings +# --------------------------------------------------------------------------- + + +class EncoderToDecoderAdapter(nn.Module): + """Two-layer projection from encoder space to decoder space.""" + + def __init__(self, enc_dim: int, dec_dim: int): + 
super().__init__() + self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False) + self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False) + + def __call__(self, x): + return self.linear2(nn.gelu(self.linear1(x))) + + +class DelayEmbedding(nn.Module): + """Sinusoidal embedding that encodes the streaming delay as a conditioning + vector for the decoder's adaptive scaling.""" + + def __init__(self, dim: int = 3072, theta: float = 10000.0): + super().__init__() + self.dim = dim + half = dim // 2 + freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half) + self._freqs = freqs + + def __call__(self, delay: mx.array) -> mx.array: + t = delay.reshape(-1, 1).astype(mx.float32) + angles = t * self._freqs + return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1) + + +# --------------------------------------------------------------------------- +# Top-level model +# --------------------------------------------------------------------------- + + +class VoxtralMLXModel(nn.Module): + """Top-level Voxtral Realtime model wiring encoder, adapter, and decoder.""" + + def __init__(self, config: dict): + super().__init__() + + enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"] + audio_cfg = enc_cfg["audio_encoding_args"] + ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"] + + self.encoder = StreamingEncoder( + mel_channels=audio_cfg["num_mel_bins"], + dim=enc_cfg["dim"], + n_layers=enc_cfg["n_layers"], + n_heads=enc_cfg["n_heads"], + head_dim=enc_cfg["head_dim"], + hidden_dim=enc_cfg["hidden_dim"], + rope_theta=enc_cfg["rope_theta"], + sliding_window=enc_cfg["sliding_window"], + ) + + adapter_input_dim = enc_cfg["dim"] * ds_factor + decoder_dim = config["dim"] + cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32) + + self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim) + + self.decoder = TextDecoder( + dim=decoder_dim, + n_layers=config["n_layers"], + n_heads=config["n_heads"], + n_kv_heads=config["n_kv_heads"], + head_dim=config["head_dim"], + hidden_dim=config["hidden_dim"], + vocab_size=config["vocab_size"], + rope_theta=config["rope_theta"], + cond_dim=cond_bottleneck, + ) + + self.delay_embedding = DelayEmbedding(dim=decoder_dim) + self.ds_factor = ds_factor + + # -- batch encode -- + + def encode(self, mel: mx.array) -> mx.array: + T = mel.shape[1] + if T % 2 != 0: + mel = mel[:, 1:] + + h = self.encoder.forward(mel) # [1, T/2, enc_dim] + h = h[0] + + n = h.shape[0] + trim = n % self.ds_factor + if trim: + h = h[trim:] + n = h.shape[0] + + h = h.reshape(n // self.ds_factor, -1) + return self.adapter(h) + + # -- incremental encode -- + + def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder): + """Incrementally encode new mel frames. 
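+
+        Conv tails, the encoder KV cache, and the downsample remainder are
+        threaded through the arguments and return value; pass ``None`` for
+        each of them on the first call.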
+ + Returns: + (audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder) + """ + x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype) + + x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2) + + if enc_cache is None: + enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))] + + x = self.encoder.forward_transformer_incremental(x, enc_cache) + x = x[0] # [N, enc_dim] + + if ds_remainder is not None: + x = mx.concatenate([ds_remainder, x]) + + n_full = (x.shape[0] // self.ds_factor) * self.ds_factor + if n_full == 0: + return None, conv_tail1, conv_tail2, enc_cache, x + + leftover = x[n_full:] if x.shape[0] > n_full else None + x = x[:n_full].reshape(n_full // self.ds_factor, -1) + return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover + + # -- decode -- + + def decode(self, embeddings, delay_cond, mask=None, cache=None): + return self.decoder(embeddings, delay_cond, mask, cache) diff --git a/whisperlivekit/voxtral_mlx/spectrogram.py b/whisperlivekit/voxtral_mlx/spectrogram.py new file mode 100644 index 0000000..0fdf463 --- /dev/null +++ b/whisperlivekit/voxtral_mlx/spectrogram.py @@ -0,0 +1,202 @@ +""" +Mel spectrogram computation for Voxtral Realtime. + +Provides both a full-audio function and an incremental streaming variant +that maintains overlap state between calls. The DFT is computed via +matrix multiplication in MLX — no external FFT dependency required. +""" + +import math + +import mlx.core as mx +import numpy as np + +# Audio / mel constants matching the Voxtral Realtime model expectations. +SAMPLE_RATE = 16_000 +WINDOW_SIZE = 400 # n_fft +HOP = 160 +MEL_BANDS = 128 +MEL_MAX = 1.5 # global log-mel normalisation ceiling +# Each output audio token spans: hop * conv_stride(2) * downsample_factor(4) +SAMPLES_PER_TOKEN = HOP * 2 * 4 # = 1280 samples = 80 ms + +# Padding tokens used by the model prompt structure. +LEFT_PAD_TOKENS = 32 +RIGHT_PAD_TOKENS = 17 + + +# --------------------------------------------------------------------------- +# Slaney mel filterbank +# --------------------------------------------------------------------------- + +def _build_slaney_filterbank( + sr: int = SAMPLE_RATE, + n_fft: int = WINDOW_SIZE, + n_mels: int = MEL_BANDS, + lo_hz: float = 0.0, + hi_hz: float = 8000.0, +) -> np.ndarray: + """Compute a Slaney-normalised triangular mel filterbank. + + Returns an array of shape ``[n_mels, n_fft//2 + 1]``. 
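+    With the defaults this is ``[128, 201]`` (201 = 400 // 2 + 1 bins).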
+ """ + + def _hz2mel(f): + threshold = 1000.0 + base_mel = 15.0 + log_coeff = 27.0 / np.log(6.4) + mel = 3.0 * f / 200.0 + if isinstance(f, np.ndarray): + above = f >= threshold + mel[above] = base_mel + np.log(f[above] / threshold) * log_coeff + elif f >= threshold: + mel = base_mel + np.log(f / threshold) * log_coeff + return mel + + def _mel2hz(m): + threshold = 1000.0 + base_mel = 15.0 + log_coeff = np.log(6.4) / 27.0 + hz = 200.0 * m / 3.0 + above = m >= base_mel + hz[above] = threshold * np.exp(log_coeff * (m[above] - base_mel)) + return hz + + n_bins = n_fft // 2 + 1 + fft_hz = np.linspace(0, sr / 2, n_bins) + mel_lo, mel_hi = _hz2mel(lo_hz), _hz2mel(hi_hz) + mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2) + hz_pts = _mel2hz(mel_pts) + diffs = np.diff(hz_pts) + + slopes = np.expand_dims(hz_pts, 0) - np.expand_dims(fft_hz, 1) + rising = -slopes[:, :-2] / diffs[:-1] + falling = slopes[:, 2:] / diffs[1:] + fb = np.maximum(0.0, np.minimum(rising, falling)) + + # Slaney area normalisation + widths = 2.0 / (hz_pts[2 : n_mels + 2] - hz_pts[:n_mels]) + fb *= np.expand_dims(widths, 0) + return fb.T.astype(np.float32) + + +_CACHED_FILTERS: mx.array | None = None + + +def _mel_filters() -> mx.array: + global _CACHED_FILTERS + if _CACHED_FILTERS is None: + _CACHED_FILTERS = mx.array(_build_slaney_filterbank()) + return _CACHED_FILTERS + + +# --------------------------------------------------------------------------- +# DFT helpers +# --------------------------------------------------------------------------- + +def _hann_window() -> mx.array: + return mx.array(np.hanning(WINDOW_SIZE + 1)[:-1].astype(np.float32)) + + +def _dft_matrices(): + """Pre-compute the real / imaginary DFT basis matrices.""" + n_bins = WINDOW_SIZE // 2 + 1 + k = mx.arange(n_bins, dtype=mx.float32)[:, None] + n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :] + phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE + return mx.cos(phase), mx.sin(phase) + + +def _stft_frames(audio: mx.array, window: mx.array) -> mx.array: + """Frame *audio* using the Hann window and compute power spectrogram.""" + n_bins = WINDOW_SIZE // 2 + 1 + n_frames = 1 + (audio.shape[0] - WINDOW_SIZE) // HOP + if n_frames <= 0: + return mx.zeros((0, n_bins)) + + offsets = (mx.arange(n_frames) * HOP)[:, None] + indices = offsets + mx.arange(WINDOW_SIZE)[None, :] + windowed = audio[indices] * window[None, :] + + dft_re, dft_im = _dft_matrices() + real_part = windowed @ dft_re.T + imag_part = windowed @ dft_im.T + return real_part ** 2 + imag_part ** 2 + + +def _apply_mel_and_log(power: mx.array) -> mx.array: + """Convert a power spectrogram to log-mel and normalise.""" + mel = power @ _mel_filters().T + log_mel = mx.log10(mx.maximum(mel, 1e-10)) + log_mel = mx.maximum(log_mel, MEL_MAX - 8.0) + return (log_mel + 4.0) / 4.0 + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def compute_mel(audio: np.ndarray) -> mx.array: + """Compute log-mel spectrogram for a complete audio signal. + + Args: + audio: 1-D float32 numpy array at ``SAMPLE_RATE``. + + Returns: + ``[MEL_BANDS, T]`` MLX array. 
+ """ + x = mx.array(audio) + pad = WINDOW_SIZE // 2 + x = mx.pad(x, [(pad, pad)]) + window = _hann_window() + + power = _stft_frames(x, window) + # Drop last frame to match reference STFT behaviour + power = power[:-1] + return _apply_mel_and_log(power).T + + +def compute_mel_streaming( + chunk: np.ndarray, + overlap: np.ndarray | None, +) -> tuple[mx.array, np.ndarray]: + """Incrementally compute log-mel for a new audio chunk. + + Args: + chunk: New audio samples (float32 numpy). + overlap: The last ``WINDOW_SIZE - HOP`` = 240 samples from the + previous call, or *None* on the first call (uses zero-padding). + + Returns: + ``(mel, new_overlap)`` where *mel* is ``[MEL_BANDS, N]`` and + *new_overlap* is the 240-sample tail for the next call. + """ + tail_len = WINDOW_SIZE - HOP # 240 + + if overlap is not None: + combined = np.concatenate([overlap, chunk]) + else: + combined = np.concatenate([np.zeros(WINDOW_SIZE // 2, dtype=np.float32), chunk]) + + new_overlap = combined[-tail_len:].copy() + + x = mx.array(combined) + window = _hann_window() + power = _stft_frames(x, window) + + if power.shape[0] == 0: + return mx.zeros((MEL_BANDS, 0)), new_overlap + + return _apply_mel_and_log(power).T, new_overlap + + +def pad_audio( + audio: np.ndarray, + n_left: int = LEFT_PAD_TOKENS, + n_right: int = RIGHT_PAD_TOKENS, +) -> np.ndarray: + """Pad audio with silence for batch (non-streaming) inference.""" + left = n_left * SAMPLES_PER_TOKEN + align = (SAMPLES_PER_TOKEN - (len(audio) % SAMPLES_PER_TOKEN)) % SAMPLES_PER_TOKEN + right = align + n_right * SAMPLES_PER_TOKEN + return np.pad(audio, (left, right)) diff --git a/whisperlivekit/voxtral_mlx_asr.py b/whisperlivekit/voxtral_mlx_asr.py new file mode 100644 index 0000000..4c62f80 --- /dev/null +++ b/whisperlivekit/voxtral_mlx_asr.py @@ -0,0 +1,521 @@ +""" +Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit. + +Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor`` +(streaming processor) that plug into WhisperLiveKit's audio processing +pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc. + +Unlike the HuggingFace backend, this runs the full inference loop in-process +(no background thread / queue) — MLX operations on Apple Silicon are fast +enough to run synchronously inside ``asyncio.to_thread(process_iter)``. +""" + +import logging +import sys +import time +from typing import List, Optional, Tuple + +import mlx.core as mx +import numpy as np +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + +from whisperlivekit.timed_objects import ASRToken, Transcript +from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID +from whisperlivekit.voxtral_mlx.model import SlidingKVCache +from whisperlivekit.voxtral_mlx.spectrogram import ( + SAMPLES_PER_TOKEN, + LEFT_PAD_TOKENS, + RIGHT_PAD_TOKENS, + compute_mel_streaming, +) + +logger = logging.getLogger(__name__) + +# Decoder sliding-window size (matches the model's training configuration). 
+_DECODER_WINDOW = 8192 + + +def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6): + """Build the prompt token sequence and return ``(token_ids, n_delay)``.""" + pad_id = tokenizer.get_special_token("[STREAMING_PAD]") + ids = [tokenizer.bos_id] + [pad_id] * (n_left_pad + n_delay) + return ids, n_delay + + +# --------------------------------------------------------------------------- +# Model holder +# --------------------------------------------------------------------------- + + +class VoxtralMLXASR: + """Lightweight model holder — loads the MLX Voxtral model once and keeps + it alive for the lifetime of the server.""" + + sep = " " + SAMPLING_RATE = 16_000 + + def __init__(self, logfile=sys.stderr, **kwargs): + self.logfile = logfile + self.transcribe_kargs = {} + + lan = kwargs.get("lan", "auto") + self.original_language = None if lan == "auto" else lan + + model_path = kwargs.get("model_dir") or kwargs.get("model_path") + if not model_path: + model_size = kwargs.get("model_size", "") + if model_size and ("/" in model_size or model_size.startswith(".")): + model_path = model_size + else: + model_path = DEFAULT_MODEL_ID + + t0 = time.time() + logger.info("Loading Voxtral MLX model '%s' ...", model_path) + self.model, self.tokenizer, self.config = load_voxtral_model(model_path) + logger.info("Voxtral MLX model loaded in %.2fs", time.time() - t0) + + self.backend_choice = "voxtral-mlx" + + def transcribe(self, audio): + pass # all work happens in the online processor + + +# --------------------------------------------------------------------------- +# Online processor +# --------------------------------------------------------------------------- + + +class VoxtralMLXOnlineProcessor: + """Streaming processor that incrementally encodes audio and decodes text + using the MLX Voxtral model. + + Lifecycle (called by ``AudioProcessor.transcription_processor``): + + insert_audio_chunk(pcm, time) → process_iter() → get_buffer() + ... repeat ... + start_silence() / end_silence() + finish() + """ + + SAMPLING_RATE = 16_000 + + def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr): + self.asr = asr + self.logfile = logfile + self.end = 0.0 + self.buffer: list = [] + self.audio_buffer = np.array([], dtype=np.float32) + + self._model = asr.model + self._tokenizer = asr.tokenizer + + # Pre-compute prompt tokens and delay conditioning (constant across utterances). + self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer) + self._prefix_len = len(self._prompt_ids) + + self._delay_cond = self._model.delay_embedding( + mx.array([self._n_delay], dtype=mx.float32) + ) + mx.eval(self._delay_cond) + + self._prompt_embeds = self._model.decoder.embed( + mx.array([self._prompt_ids]) + )[0] # [prefix_len, dim] + mx.eval(self._prompt_embeds) + + self._eos_id = self._tokenizer.eos_id + self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE + # The streaming model has an inherent delay: text for audio at position P + # is generated at decoder position P + n_delay. Compensate timestamps. 
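+        # With the default n_delay=6 this is 6 * 80 ms = 0.48 s.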
+ self._delay_secs = self._n_delay * self._secs_per_token + + self._reset_state() + + # -- state management -- + + def _reset_state(self): + """Reset all incremental state for a fresh utterance.""" + # Audio accumulation + self._pending = np.zeros(0, dtype=np.float32) + # Mel overlap + self._mel_overlap: np.ndarray | None = None + # Encoder incremental state + self._conv_tail1 = None + self._conv_tail2 = None + self._enc_cache = None + self._ds_remainder = None + # Audio embeddings not yet decoded + self._audio_embeds: mx.array | None = None + # Decoder state + self._dec_cache: list[SlidingKVCache] | None = None + self._last_token: mx.array | None = None + # Bookkeeping + self._samples_encoded = 0 + self._positions_decoded = 0 + self._prefilled = False + self._first_chunk = True + # Text state + self._full_text = "" + self._n_text_tokens = 0 + self._n_committed_words = 0 + self._time_offset = 0.0 + # Per-word audio position tracking: decoder position (relative to prefix) + # where each word in _full_text started and ended + self._word_audio_starts: list[int] = [] # audio pos where word i started + self._word_audio_ends: list[int] = [] # audio pos where word i last produced a token + self._current_word_pos: Optional[int] = None # audio pos of current (incomplete) word's first token + + # -- audio ingestion -- + + def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): + self.end = audio_stream_end_time + self._pending = np.append(self._pending, audio) + self.audio_buffer = self._pending + + # -- core processing -- + + def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: + try: + return self._step(is_last) + except Exception as e: + logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True) + return [], self.end + + def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]: + # 1. Encode any new audio + self._encode_pending() + + if self._audio_embeds is None: + return [], self.end + + # 2. Compute how many positions we can safely decode + total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN + n_available = self._audio_embeds.shape[0] + n_decodable = min(n_available, total_safe - self._positions_decoded) + + if n_decodable <= 0: + return [], self.end + + # 3. Prefill if needed + if not self._prefilled: + if self._positions_decoded + n_available < self._prefix_len: + return [], self.end + self._do_prefill() + # Re-check after consuming prefix embeddings + n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0 + n_decodable = min(n_available, total_safe - self._positions_decoded) + + if n_decodable <= 0 or self._audio_embeds is None: + return [], self.end + + # 4. Decode available positions + hit_eos = self._decode_positions(n_decodable) + + if hit_eos: + # Flush words, reset for next utterance + words = self._flush_all_words() + logger.debug( + "[voxtral-mlx] EOS hit during stream: flushed %d words, " + "samples_encoded=%d (%.2fs), text='%s'", + len(words), self._samples_encoded, + self._samples_encoded / self.SAMPLING_RATE, + self._full_text[-60:] if self._full_text else "", + ) + saved_offset = self._time_offset + self._reset_state() + self._time_offset = saved_offset + return words, self.end + + # 5. 
Extract committed words (all but the last, which may still grow) + return self._extract_committed_words(), self.end + + def _encode_pending(self): + """Feed pending audio through the incremental encoder.""" + available = len(self._pending) + if available < SAMPLES_PER_TOKEN: + return + + if self._first_chunk: + # First chunk: prepend silence for left-padding + n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN + left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32) + chunk = np.concatenate([left_pad, self._pending[:n_take]]) + self._pending = self._pending[n_take:] + self._samples_encoded += n_take + self._first_chunk = False + else: + n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN + chunk = self._pending[:n_take] + self._pending = self._pending[n_take:] + self._samples_encoded += n_take + + mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap) + + embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = ( + self._model.encode_incremental( + mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder + ) + ) + + if embeds is not None: + mx.eval(embeds) + if self._audio_embeds is not None: + self._audio_embeds = mx.concatenate([self._audio_embeds, embeds]) + else: + self._audio_embeds = embeds + + self.audio_buffer = self._pending + + def _do_prefill(self): + """Run the decoder prefill pass over the prompt + first audio embeddings.""" + n_dec_layers = len(self._model.decoder.blocks) + self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)] + + prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len] + prefix_embeds = prefix_embeds[None, :, :] # [1, prefix_len, dim] + + logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache) + mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)]) + + self._last_token = self._sample(logits) + mx.async_eval(self._last_token) + + # Remove consumed prefix embeddings + self._audio_embeds = self._audio_embeds[self._prefix_len :] + if self._audio_embeds.shape[0] == 0: + self._audio_embeds = None + self._positions_decoded = self._prefix_len + self._prefilled = True + + def _decode_positions(self, n: int) -> bool: + """Autoregressively decode *n* positions. 
Returns True on EOS.""" + base_pos = self._positions_decoded # absolute position before this batch + for i in range(n): + tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0] + combined = (self._audio_embeds[i] + tok_embed)[None, None, :] + logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache) + next_tok = self._sample(logits) + mx.async_eval(next_tok) + + token_id = self._last_token.item() + if token_id == self._eos_id: + # Close the current word if one is being built + if self._current_word_pos is not None: + self._word_audio_ends.append(base_pos + i - self._prefix_len) + self._current_word_pos = None + self._trim_embeds(i) + self._positions_decoded += i + return True + + text = self._tokenizer.decode( + [token_id], special_token_policy=SpecialTokenPolicy.IGNORE + ) + + if text: + audio_pos = base_pos + i - self._prefix_len + + # Detect word boundary: new word starts with space or is the very first text + if text.lstrip() != text or not self._full_text: + # Close previous word if exists + if self._current_word_pos is not None: + self._word_audio_ends.append(audio_pos) + # Start new word + self._word_audio_starts.append(audio_pos) + self._current_word_pos = audio_pos + elif self._current_word_pos is None: + # First token of first word (no leading space) + self._word_audio_starts.append(audio_pos) + self._current_word_pos = audio_pos + + self._full_text += text + self._n_text_tokens += 1 + + if i > 0 and i % 256 == 0: + mx.clear_cache() + + self._last_token = next_tok + + self._positions_decoded += n + self._trim_embeds(n) + return False + + def _trim_embeds(self, n_consumed: int): + if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed: + self._audio_embeds = self._audio_embeds[n_consumed:] + else: + self._audio_embeds = None + + def _sample(self, logits: mx.array) -> mx.array: + return mx.argmax(logits[0, -1:], axis=-1).squeeze() + + # -- word extraction -- + + def _audio_pos_to_time(self, pos: int) -> float: + """Convert an audio position (relative to prefix end) to seconds.""" + return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset) + + def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]: + """Compute (start, end) time for a word using tracked word positions.""" + starts = self._word_audio_starts + ends = self._word_audio_ends + + if not starts: + return self._time_offset, self._time_offset + + # Get start position for this word + if word_idx < len(starts): + t0 = self._audio_pos_to_time(starts[word_idx]) + else: + # Fallback: estimate from last known position + last_pos = ends[-1] if ends else starts[-1] + t0 = self._audio_pos_to_time(last_pos + 1) + + # Get end position: use the start of the next word, or the end of this word + if word_idx + 1 < len(starts): + t1 = self._audio_pos_to_time(starts[word_idx + 1]) + elif word_idx < len(ends): + t1 = self._audio_pos_to_time(ends[word_idx] + 1) + else: + # Last word, still being built: use last known position + 1 token + last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0) + t1 = self._audio_pos_to_time(last_pos + 1) + + return t0, t1 + + def _extract_committed_words(self) -> List[ASRToken]: + """Return complete words (all except the last which may still grow).""" + if not self._full_text: + return [] + words = self._full_text.split() + tokens: List[ASRToken] = [] + n_total = max(len(words), 1) + + while len(words) > self._n_committed_words + 1: + w = 
words[self._n_committed_words] + idx = self._n_committed_words + t0, t1 = self._word_time_range(idx, n_total) + label = w if idx == 0 else " " + w + tokens.append(ASRToken(start=t0, end=t1, text=label)) + self._n_committed_words += 1 + + return tokens + + def _flush_all_words(self) -> List[ASRToken]: + """Flush every word including the last partial one.""" + if not self._full_text: + return [] + words = self._full_text.split() + tokens: List[ASRToken] = [] + n_total = max(len(words), 1) + + while self._n_committed_words < len(words): + w = words[self._n_committed_words] + idx = self._n_committed_words + t0, t1 = self._word_time_range(idx, n_total) + label = w if idx == 0 else " " + w + tokens.append(ASRToken(start=t0, end=t1, text=label)) + self._n_committed_words += 1 + + return tokens + + # -- interface methods -- + + def get_buffer(self) -> Transcript: + if not self._full_text: + return Transcript(start=None, end=None, text="") + words = self._full_text.split() + remaining = words[self._n_committed_words :] + if remaining: + return Transcript(start=self.end, end=self.end, text=" ".join(remaining)) + return Transcript(start=None, end=None, text="") + + def start_silence(self) -> Tuple[List[ASRToken], float]: + words = self._flush_all_words() + logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words)) + return words, self.end + + def end_silence(self, silence_duration: float, offset: float): + self._time_offset += silence_duration + self.end += silence_duration + + def new_speaker(self, change_speaker): + self.start_silence() + + def warmup(self, audio, init_prompt=""): + pass + + def finish(self) -> Tuple[List[ASRToken], float]: + logger.debug( + "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, " + "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'", + len(self._pending), + self._audio_embeds.shape if self._audio_embeds is not None else None, + self._samples_encoded, + self._positions_decoded, + self._prefilled, + self._full_text[-80:] if self._full_text else "", + ) + + # Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost + remainder = len(self._pending) % SAMPLES_PER_TOKEN + if remainder > 0: + align_pad = SAMPLES_PER_TOKEN - remainder + else: + align_pad = 0 + + # Add alignment + right-padding silence + total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN + if total_pad > 0: + self._pending = np.append( + self._pending, np.zeros(total_pad, dtype=np.float32) + ) + + # Encode remaining audio (including right-padding) + self._encode_pending() + + logger.debug( + "[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d", + self._audio_embeds.shape if self._audio_embeds is not None else None, + len(self._pending), + ) + + hit_eos = False + + # Decode everything that's left from right-padding + if self._audio_embeds is not None and self._prefilled: + hit_eos = self._decode_positions(self._audio_embeds.shape[0]) + logger.debug( + "[voxtral-mlx] finish decode: hit_eos=%s, text='%s'", + hit_eos, self._full_text[-80:] if self._full_text else "", + ) + + # Flush last token if it wasn't EOS + if self._last_token is not None: + tid = self._last_token.item() + if tid != self._eos_id: + text = self._tokenizer.decode( + [tid], special_token_policy=SpecialTokenPolicy.IGNORE + ) + if text: + last_pos = self._positions_decoded - self._prefix_len + # Check if this starts a new word + if text.lstrip() != text or not self._full_text: + if self._current_word_pos is not None: + self._word_audio_ends.append(last_pos) + 
self._word_audio_starts.append(last_pos) + self._current_word_pos = last_pos + elif self._current_word_pos is None: + self._word_audio_starts.append(last_pos) + self._current_word_pos = last_pos + self._full_text += text + self._n_text_tokens += 1 + + # Close the last word if still open + if self._current_word_pos is not None: + last_pos = self._positions_decoded - self._prefix_len + self._word_audio_ends.append(last_pos) + self._current_word_pos = None + + words = self._flush_all_words() + logger.info("[voxtral-mlx] finish: flushed %d words", len(words)) + return words, self.end
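
Usage sketch: a minimal synchronous driver for the new backend, assuming a
16 kHz mono float32 array named ``audio``; the 0.5 s chunk size, variable
names, and printing are illustrative, not part of the API.

    import numpy as np
    from whisperlivekit.voxtral_mlx_asr import VoxtralMLXASR, VoxtralMLXOnlineProcessor

    asr = VoxtralMLXASR()                       # resolves to DEFAULT_MODEL_ID
    proc = VoxtralMLXOnlineProcessor(asr)

    CHUNK = 8_000                               # 0.5 s at 16 kHz (illustrative)
    for i in range(0, len(audio), CHUNK):
        piece = audio[i : i + CHUNK].astype(np.float32)
        proc.insert_audio_chunk(piece, (i + len(piece)) / 16_000)
        committed, _ = proc.process_iter()      # words that can no longer change
        for tok in committed:
            print(f"[{tok.start:6.2f}-{tok.end:6.2f}]{tok.text}")

    tail, _ = proc.finish()                     # right-pads, decodes the tail, flushes
    print("".join(tok.text for tok in tail))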