Files
WhisperLiveKit/whisperlivekit/voxtral_mlx/model.py
Quentin Fuxa a4da246ea5 feat: add voxtral-mlx native backend for Apple Silicon
Pure-MLX implementation of Voxtral Mini 4B Realtime for low-latency
speech transcription on Apple Silicon. Avoids the transformers/torch
overhead and runs at 0.18-0.32x real-time factor.

- voxtral_mlx/model.py: MLX model with spectrogram, encoder, decoder
- voxtral_mlx/loader.py: model loading with 6-bit quantized weights
- voxtral_mlx/spectrogram.py: mel spectrogram computation in MLX
- voxtral_mlx_asr.py: VoxtralASR adapter for the AudioProcessor pipeline
2026-02-22 23:28:10 +01:00

535 lines
19 KiB
Python

"""
Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model.
Architecture:
audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits
with DelayEmbedding providing time-conditioning to the decoder.
The model supports both batch inference (full audio) and incremental streaming
(one chunk at a time with cached encoder/decoder state).
"""
import math
import mlx.core as mx
import mlx.nn as nn
# ---------------------------------------------------------------------------
# KV Cache
# ---------------------------------------------------------------------------
class SlidingKVCache:
"""Bounded key-value cache with rotating buffer for sliding-window attention.
Uses in-place writes for single-token autoregressive steps and
concatenation for multi-token prefills. Pre-allocates in blocks of
``alloc_step`` entries to reduce repeated allocation.
"""
alloc_step = 256
def __init__(self, capacity: int):
self.capacity = capacity
self.keys = None
self.values = None
self._offset = 0
self._write_idx = 0
@property
def offset(self) -> int:
return self._offset
# -- helpers --
def _reorder(self, buf):
"""Return *buf* in temporal order (unwrap the circular buffer)."""
if self._write_idx == buf.shape[2]:
return buf
if self._write_idx < self._offset:
return mx.concatenate(
[buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
axis=2,
)
return buf[..., : self._write_idx, :]
def _drop_oldest(self, buf, n_drop, tail=None):
parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
if tail is not None:
parts.append(tail)
return mx.concatenate(parts, axis=2)
# -- update strategies --
def _append_concat(self, k, v):
"""Multi-token update via concatenation (used during prefill)."""
if self.keys is None:
self.keys, self.values = k, v
else:
self.keys = self._reorder(self.keys)
self.values = self._reorder(self.values)
self._write_idx = self.keys.shape[2]
overflow = self._write_idx - self.capacity + 1
self.keys = self._drop_oldest(self.keys, overflow, k)
self.values = self._drop_oldest(self.values, overflow, v)
self._offset += k.shape[2]
self._write_idx = self.keys.shape[2]
return self.keys, self.values
def _write_inplace(self, k, v):
"""Single-token update via in-place write (autoregressive step)."""
B, n_heads, S, dim_k = k.shape
dim_v = v.shape[3]
prev = self._offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
):
n_new = min(self.alloc_step, self.capacity - prev)
fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
self.values = mx.concatenate([self.values, fresh_v], axis=2)
else:
self.keys, self.values = fresh_k, fresh_v
self._write_idx = prev
overflow = self.keys.shape[2] - self.capacity
if overflow > 0:
self.keys = self._drop_oldest(self.keys, overflow)
self.values = self._drop_oldest(self.values, overflow)
self._write_idx = self.capacity
if self._write_idx == self.capacity:
self._write_idx = 0
self.keys[..., self._write_idx : self._write_idx + S, :] = k
self.values[..., self._write_idx : self._write_idx + S, :] = v
self._offset += S
self._write_idx += S
if self._offset < self.capacity:
return (
self.keys[..., : self._offset, :],
self.values[..., : self._offset, :],
)
return self.keys, self.values
# -- public API --
def update_and_fetch(self, k, v):
if k.shape[2] == 1:
return self._write_inplace(k, v)
return self._append_concat(k, v)
# ---------------------------------------------------------------------------
# Encoder components
# ---------------------------------------------------------------------------
class CausalConv(nn.Module):
"""1-D causal convolution (left-padded so no future leakage)."""
def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
super().__init__()
self.stride = stride
self.kernel = kernel
self.left_pad = kernel - stride
self.weight = mx.zeros((channels_out, kernel, channels_in))
self.bias = mx.zeros((channels_out,))
def __call__(self, x: mx.array) -> mx.array:
if self.left_pad > 0:
x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
class _EncoderSelfAttention(nn.Module):
def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope_theta = rope_theta
def __call__(self, x, mask, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _EncoderFFN(nn.Module):
"""SwiGLU feed-forward for encoder layers."""
def __init__(self, dim: int, hidden: int):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=True)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class _EncoderBlock(nn.Module):
def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _EncoderFFN(dim, hidden)
def __call__(self, x, mask, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
x = x + self.ffn(self.pre_ffn_norm(x))
return x
class StreamingEncoder(nn.Module):
"""Causal Whisper-style encoder with two causal convolutions followed by
a stack of transformer blocks. Supports both full-sequence and
incremental (streaming) forward passes."""
def __init__(
self,
mel_channels: int = 128,
dim: int = 1280,
n_layers: int = 32,
n_heads: int = 32,
head_dim: int = 64,
hidden_dim: int = 5120,
rope_theta: float = 1e6,
sliding_window: int = 750,
):
super().__init__()
self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
self.blocks = [
_EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
self.sliding_window = sliding_window
# -- full-sequence --
def _apply_convs(self, mel: mx.array) -> mx.array:
x = mel.T[None, :, :] # [1, T, mel_channels]
x = nn.gelu(self.conv1(x))
x = nn.gelu(self.conv2(x))
return x
def forward(self, mel: mx.array) -> mx.array:
x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
for blk in self.blocks:
x = blk(x, mask="causal")
return self.final_norm(x)
# -- incremental (streaming) --
def forward_conv_incremental(self, x_in, tail1, tail2):
"""Process new mel frames through the two causal convs using cached tails.
Args:
x_in: [1, N, mel_channels]
tail1: [1, pad1, mel_channels] or None (first call)
tail2: [1, pad2, dim] or None (first call)
Returns:
(out, new_tail1, new_tail2)
"""
# Conv1 (kernel=3, stride=1 → left_pad=2)
if tail1 is not None:
c1_in = mx.concatenate([tail1, x_in], axis=1)
else:
c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
new_tail1 = x_in[:, -self.conv1.left_pad :, :]
c1_out = nn.gelu(
mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
)
# Conv2 (kernel=3, stride=2 → left_pad=1)
if tail2 is not None:
c2_in = mx.concatenate([tail2, c1_out], axis=1)
else:
c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
c2_out = nn.gelu(
mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
)
return c2_out, new_tail1, new_tail2
def forward_transformer_incremental(self, x, cache_list):
"""Run transformer blocks with per-layer KV caches."""
for i, blk in enumerate(self.blocks):
x = blk(x, mask="causal", cache=cache_list[i])
return self.final_norm(x)
# ---------------------------------------------------------------------------
# Decoder components
# ---------------------------------------------------------------------------
class _DecoderAttention(nn.Module):
"""Grouped-query attention for the text decoder."""
def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope_theta = rope_theta
def __call__(self, x, mask=None, cache=None):
B, L, _ = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
pos = cache.offset if cache is not None else 0
q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
if cache is not None:
k, v = cache.update_and_fetch(k, v)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class _DecoderFFN(nn.Module):
"""SwiGLU feed-forward for decoder layers."""
def __init__(self, dim, hidden):
super().__init__()
self.gate = nn.Linear(dim, hidden, bias=False)
self.up = nn.Linear(dim, hidden, bias=False)
self.down = nn.Linear(hidden, dim, bias=False)
def __call__(self, x):
return self.down(nn.silu(self.gate(x)) * self.up(x))
class AdaptiveScaling(nn.Module):
"""Small MLP that produces a multiplicative scale from the delay embedding,
used to condition the FFN on the streaming delay."""
def __init__(self, dim, bottleneck):
super().__init__()
self.proj_in = nn.Linear(dim, bottleneck, bias=False)
self.proj_out = nn.Linear(bottleneck, dim, bias=False)
def __call__(self, cond):
return self.proj_out(nn.gelu(self.proj_in(cond)))
class _DecoderBlock(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
super().__init__()
self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
self.ffn = _DecoderFFN(dim, hidden)
def __call__(self, x, delay_cond, mask=None, cache=None):
x = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond))
x = x + self.ffn(scaled)
return x
class TextDecoder(nn.Module):
"""Mistral-style causal language model with adaptive time-conditioning."""
def __init__(
self,
dim: int = 3072,
n_layers: int = 26,
n_heads: int = 32,
n_kv_heads: int = 8,
head_dim: int = 128,
hidden_dim: int = 9216,
vocab_size: int = 131072,
rope_theta: float = 1e6,
cond_dim: int = 32,
):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, dim)
self.blocks = [
_DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
for _ in range(n_layers)
]
self.final_norm = nn.RMSNorm(dim, eps=1e-5)
def embed(self, token_ids: mx.array) -> mx.array:
return self.token_embedding(token_ids)
def __call__(self, x, delay_cond, mask=None, cache=None):
delay_cond = delay_cond.astype(x.dtype)
for i, blk in enumerate(self.blocks):
blk_cache = cache[i] if cache is not None else None
x = blk(x, delay_cond, mask, blk_cache)
x = self.final_norm(x)
return self.token_embedding.as_linear(x)
# ---------------------------------------------------------------------------
# Adapter & embeddings
# ---------------------------------------------------------------------------
class EncoderToDecoderAdapter(nn.Module):
"""Two-layer projection from encoder space to decoder space."""
def __init__(self, enc_dim: int, dec_dim: int):
super().__init__()
self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)
def __call__(self, x):
return self.linear2(nn.gelu(self.linear1(x)))
class DelayEmbedding(nn.Module):
"""Sinusoidal embedding that encodes the streaming delay as a conditioning
vector for the decoder's adaptive scaling."""
def __init__(self, dim: int = 3072, theta: float = 10000.0):
super().__init__()
self.dim = dim
half = dim // 2
freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
self._freqs = freqs
def __call__(self, delay: mx.array) -> mx.array:
t = delay.reshape(-1, 1).astype(mx.float32)
angles = t * self._freqs
return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
class VoxtralMLXModel(nn.Module):
"""Top-level Voxtral Realtime model wiring encoder, adapter, and decoder."""
def __init__(self, config: dict):
super().__init__()
enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
audio_cfg = enc_cfg["audio_encoding_args"]
ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]
self.encoder = StreamingEncoder(
mel_channels=audio_cfg["num_mel_bins"],
dim=enc_cfg["dim"],
n_layers=enc_cfg["n_layers"],
n_heads=enc_cfg["n_heads"],
head_dim=enc_cfg["head_dim"],
hidden_dim=enc_cfg["hidden_dim"],
rope_theta=enc_cfg["rope_theta"],
sliding_window=enc_cfg["sliding_window"],
)
adapter_input_dim = enc_cfg["dim"] * ds_factor
decoder_dim = config["dim"]
cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)
self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)
self.decoder = TextDecoder(
dim=decoder_dim,
n_layers=config["n_layers"],
n_heads=config["n_heads"],
n_kv_heads=config["n_kv_heads"],
head_dim=config["head_dim"],
hidden_dim=config["hidden_dim"],
vocab_size=config["vocab_size"],
rope_theta=config["rope_theta"],
cond_dim=cond_bottleneck,
)
self.delay_embedding = DelayEmbedding(dim=decoder_dim)
self.ds_factor = ds_factor
# -- batch encode --
def encode(self, mel: mx.array) -> mx.array:
T = mel.shape[1]
if T % 2 != 0:
mel = mel[:, 1:]
h = self.encoder.forward(mel) # [1, T/2, enc_dim]
h = h[0]
n = h.shape[0]
trim = n % self.ds_factor
if trim:
h = h[trim:]
n = h.shape[0]
h = h.reshape(n // self.ds_factor, -1)
return self.adapter(h)
# -- incremental encode --
def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
"""Incrementally encode new mel frames.
Returns:
(audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
"""
x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)
x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)
if enc_cache is None:
enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]
x = self.encoder.forward_transformer_incremental(x, enc_cache)
x = x[0] # [N, enc_dim]
if ds_remainder is not None:
x = mx.concatenate([ds_remainder, x])
n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
if n_full == 0:
return None, conv_tail1, conv_tail2, enc_cache, x
leftover = x[n_full:] if x.shape[0] > n_full else None
x = x[:n_full].reshape(n_full // self.ds_factor, -1)
return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover
# -- decode --
def decode(self, embeddings, delay_cond, mask=None, cache=None):
return self.decoder(embeddings, delay_cond, mask, cache)