mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
feat: add voxtral-mlx native backend for Apple Silicon
Pure-MLX implementation of Voxtral Mini 4B Realtime for low-latency speech transcription on Apple Silicon. Avoids the transformers/torch overhead and runs at 0.18-0.32x real-time factor. - voxtral_mlx/model.py: MLX model with spectrogram, encoder, decoder - voxtral_mlx/loader.py: model loading with 6-bit quantized weights - voxtral_mlx/spectrogram.py: mel spectrogram computation in MLX - voxtral_mlx_asr.py: VoxtralASR adapter for the AudioProcessor pipeline
This commit is contained in:
6
whisperlivekit/voxtral_mlx/__init__.py
Normal file
6
whisperlivekit/voxtral_mlx/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Pure-MLX Voxtral Realtime backend for WhisperLiveKit."""
|
||||
|
||||
from .loader import load_voxtral_model
|
||||
from .model import VoxtralMLXModel
|
||||
|
||||
__all__ = ["load_voxtral_model", "VoxtralMLXModel"]
|
||||
282
whisperlivekit/voxtral_mlx/loader.py
Normal file
282
whisperlivekit/voxtral_mlx/loader.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Model weight loading for the MLX Voxtral Realtime backend.
|
||||
|
||||
Supports two on-disk formats:
|
||||
1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load,
|
||||
with optional quantisation metadata.
|
||||
2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``):
|
||||
requires weight renaming and conv-weight transposition.
|
||||
|
||||
The public entry point is :func:`load_voxtral_model` which returns the
|
||||
model, tokenizer, and raw config dict.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from .model import VoxtralMLXModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default checkpoint: 6-bit quantized Voxtral Mini 4B Realtime on the Hub.
DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"


# ---------------------------------------------------------------------------
# Downloading
# ---------------------------------------------------------------------------

# File patterns fetched from the Hub — covers both supported on-disk layouts:
# original Mistral (consolidated.safetensors + params.json) and converted
# (model*.safetensors + config.json), plus the Tekken tokenizer file.
_ALLOWED_PATTERNS = [
    "consolidated.safetensors",
    "model*.safetensors",
    "model.safetensors.index.json",
    "params.json",
    "config.json",
    "tekken.json",
]
|
||||
|
||||
|
||||
def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path:
    """Fetch the model snapshot from the HuggingFace Hub.

    Only files matching ``_ALLOWED_PATTERNS`` are downloaded.

    Args:
        model_id: HuggingFace repository ID to fetch.

    Returns:
        Local snapshot directory containing the downloaded files.
    """
    local_dir = snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS)
    return Path(local_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Weight name remapping (Mistral → our naming)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Regex rules mapping original Mistral checkpoint weight names to this
# package's module attribute paths. Each entry is (pattern, replacement);
# patterns are anchored with ^...$ by _translate_weight_name, so a rule must
# match the whole (prefix-stripped) name. Order matters only in that the
# first matching rule wins.
_NAME_RULES: list[tuple[str, str]] = [
    # Encoder convolutions
    (r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"),
    (r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"),
    # Encoder transformer blocks
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)",
     r"encoder.blocks.\1.self_attn.q_proj.\2"),
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)",
     r"encoder.blocks.\1.self_attn.k_proj.\2"),
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)",
     r"encoder.blocks.\1.self_attn.v_proj.\2"),
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)",
     r"encoder.blocks.\1.self_attn.out_proj.\2"),
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)",
     r"encoder.blocks.\1.pre_attn_norm.\2"),
    # Mistral SwiGLU naming: w1=gate, w2=down, w3=up.
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)",
     r"encoder.blocks.\1.ffn.gate.\2"),
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)",
     r"encoder.blocks.\1.ffn.down.\2"),
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)",
     r"encoder.blocks.\1.ffn.up.\2"),
    (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)",
     r"encoder.blocks.\1.pre_ffn_norm.\2"),
    (r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"),
    # Adapter (a 2-layer Sequential in the original → linear1/linear2 here)
    (r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"),
    (r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"),
    # Decoder embedding
    (r"tok_embeddings\.weight", r"decoder.token_embedding.weight"),
    # Decoder blocks
    (r"layers\.(\d+)\.attention\.wq\.weight",
     r"decoder.blocks.\1.self_attn.q_proj.weight"),
    (r"layers\.(\d+)\.attention\.wk\.weight",
     r"decoder.blocks.\1.self_attn.k_proj.weight"),
    (r"layers\.(\d+)\.attention\.wv\.weight",
     r"decoder.blocks.\1.self_attn.v_proj.weight"),
    (r"layers\.(\d+)\.attention\.wo\.weight",
     r"decoder.blocks.\1.self_attn.out_proj.weight"),
    (r"layers\.(\d+)\.attention_norm\.weight",
     r"decoder.blocks.\1.pre_attn_norm.weight"),
    (r"layers\.(\d+)\.feed_forward\.w1\.weight",
     r"decoder.blocks.\1.ffn.gate.weight"),
    (r"layers\.(\d+)\.feed_forward\.w2\.weight",
     r"decoder.blocks.\1.ffn.down.weight"),
    (r"layers\.(\d+)\.feed_forward\.w3\.weight",
     r"decoder.blocks.\1.ffn.up.weight"),
    (r"layers\.(\d+)\.ffn_norm\.weight",
     r"decoder.blocks.\1.pre_ffn_norm.weight"),
    # Adaptive RMS-norm time-conditioning MLP (Sequential idx 0/2 → in/out)
    (r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight",
     r"decoder.blocks.\1.adaptive_scale.proj_in.weight"),
    (r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight",
     r"decoder.blocks.\1.adaptive_scale.proj_out.weight"),
    # Decoder final norm
    (r"norm\.weight", r"decoder.final_norm.weight"),
]
|
||||
|
||||
# Container prefixes used in the original checkpoint that have no counterpart
# in this model's module tree; stripped before the _NAME_RULES patterns are
# applied (which are written without these prefixes).
_PREFIX_STRIP = re.compile(
    r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\."
)
|
||||
|
||||
|
||||
def _translate_weight_name(name: str) -> str | None:
    """Map an original-format weight name onto this package's naming scheme.

    The container prefix is stripped first, then the first matching rule in
    ``_NAME_RULES`` is applied (fully anchored).

    Returns:
        The translated name, or ``None`` when no rule matches (caller skips
        such weights).
    """
    stripped = _PREFIX_STRIP.sub("", name)
    for src_pattern, dst_template in _NAME_RULES:
        translated, match_count = re.subn(f"^{src_pattern}$", dst_template, stripped)
        if match_count > 0:
            return translated
    return None
|
||||
|
||||
|
||||
def _is_conv_weight(name: str) -> bool:
|
||||
return ("conv1.weight" in name or "conv2.weight" in name) and "bias" not in name
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Converted-format weight remapping (voxmlx names → our names)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Regex rules mapping converted-format (voxmlx-style) weight names to this
# package's attribute paths. Applied fully anchored by _remap_converted_name;
# names with no matching rule pass through unchanged.
_CONVERTED_RULES: list[tuple[str, str]] = [
    # Adapter
    (r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"),
    (r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"),
    # Encoder transformer blocks
    (r"encoder\.layers\.(\d+)\.attention\.(.*)", r"encoder.blocks.\1.self_attn.\2"),
    (r"encoder\.layers\.(\d+)\.attn_norm\.(.*)", r"encoder.blocks.\1.pre_attn_norm.\2"),
    (r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"encoder.blocks.\1.ffn.gate.\2"),
    (r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"encoder.blocks.\1.ffn.down.\2"),
    (r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"encoder.blocks.\1.ffn.up.\2"),
    (r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)", r"encoder.blocks.\1.pre_ffn_norm.\2"),
    (r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"),
    # Decoder embedding
    (r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"),
    # Decoder blocks
    (r"language_model\.layers\.(\d+)\.attention\.(.*)", r"decoder.blocks.\1.self_attn.\2"),
    (r"language_model\.layers\.(\d+)\.attn_norm\.(.*)", r"decoder.blocks.\1.pre_attn_norm.\2"),
    (r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)", r"decoder.blocks.\1.ffn.gate.\2"),
    (r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)", r"decoder.blocks.\1.ffn.down.\2"),
    (r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)", r"decoder.blocks.\1.ffn.up.\2"),
    (r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)", r"decoder.blocks.\1.pre_ffn_norm.\2"),
    # Adaptive RMS-norm time-conditioning MLP
    (r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)",
     r"decoder.blocks.\1.adaptive_scale.proj_in.\2"),
    (r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)",
     r"decoder.blocks.\1.adaptive_scale.proj_out.\2"),
    (r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"),
]
|
||||
|
||||
# Also remap o_proj → out_proj in both encoder and decoder
|
||||
# Also remap o_proj → out_proj in both encoder and decoder.
# Applied unanchored after _CONVERTED_RULES, so it catches the attention
# output projection regardless of which rule (if any) matched first.
_POST_RENAME = [
    (r"\.o_proj\.", r".out_proj."),
]
|
||||
|
||||
|
||||
def _remap_converted_name(name: str) -> str:
    """Translate a converted-format weight name to our naming convention.

    Unlike :func:`_translate_weight_name`, names with no matching rule are
    returned unchanged — the converted layout is already close to ours.
    """
    for src_pattern, dst_template in _CONVERTED_RULES:
        renamed, match_count = re.subn(f"^{src_pattern}$", dst_template, name)
        if match_count > 0:
            name = renamed
            break
    # Unanchored second pass: o_proj → out_proj wherever it appears.
    for src_pattern, dst_template in _POST_RENAME:
        name = re.sub(src_pattern, dst_template, name)
    return name
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loading strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _has_converted_layout(path: Path) -> bool:
|
||||
return (path / "config.json").exists() and not (path / "consolidated.safetensors").exists()
|
||||
|
||||
|
||||
def _load_converted_weights(path: Path):
    """Load a model stored in the converted layout.

    Reads ``config.json``, builds the model, applies quantization if the
    config carries a ``quantization`` section, then loads the (possibly
    sharded) safetensors weights under our naming scheme.

    Returns:
        ``(model, config)``
    """
    with open(path / "config.json") as f:
        config = json.load(f)

    model = VoxtralMLXModel(config)

    quant = config.get("quantization")
    if quant is not None:
        gs = quant["group_size"]
        # Quantize BEFORE loading weights so the module tree matches the
        # quantized tensors stored in the checkpoint. Only layers that
        # support quantization and whose last dim divides the group size
        # are converted.
        nn.quantize(
            model,
            group_size=gs,
            bits=quant["bits"],
            class_predicate=lambda _p, m: (
                hasattr(m, "to_quantized") and m.weight.shape[-1] % gs == 0
            ),
        )

    # Sharded checkpoints carry an index file mapping weight → shard file.
    index_file = path / "model.safetensors.index.json"
    if index_file.exists():
        with open(index_file) as f:
            shard_map = json.load(f)
        shard_files = sorted(set(shard_map["weight_map"].values()))
        weights = {}
        for sf in shard_files:
            weights.update(mx.load(str(path / sf)))
    else:
        weights = mx.load(str(path / "model.safetensors"))

    remapped = {_remap_converted_name(k): v for k, v in weights.items()}
    model.load_weights(list(remapped.items()))
    # Force materialization now so load errors surface here, not mid-inference.
    mx.eval(model.parameters())
    return model, config
|
||||
|
||||
|
||||
def _load_original_weights(path: Path):
    """Load a model stored in the original Mistral layout.

    Reads ``params.json``, builds the model, then renames every weight via
    ``_translate_weight_name`` and transposes conv kernels into MLX layout.

    Returns:
        ``(model, config)``
    """
    with open(path / "params.json") as f:
        config = json.load(f)

    model = VoxtralMLXModel(config)

    raw = mx.load(str(path / "consolidated.safetensors"))
    mapped: dict[str, mx.array] = {}
    skipped: list[str] = []

    for name, tensor in raw.items():
        # The output head is tied to the token embedding (see
        # TextDecoder.__call__ using token_embedding.as_linear), so the
        # separate output.weight is deliberately dropped.
        if name == "output.weight":
            continue
        new_name = _translate_weight_name(name)
        if new_name is None:
            skipped.append(name)
            continue
        # Conv weights: PyTorch [C_out, C_in, K] → MLX [C_out, K, C_in]
        if _is_conv_weight(new_name):
            tensor = mx.swapaxes(tensor, 1, 2)
        mapped[new_name] = tensor

    if skipped:
        logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5])

    model.load_weights(list(mapped.items()))
    # Force materialization now so load errors surface here, not mid-inference.
    mx.eval(model.parameters())
    return model, config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_tokenizer(model_dir: Path):
    """Build the Tekken tokenizer from ``tekken.json`` in *model_dir*.

    Imported lazily so mistral_common is only required when loading a model.
    """
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    tekken_file = model_dir / "tekken.json"
    return Tekkenizer.from_file(str(tekken_file))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID):
    """Load a Voxtral Realtime model and its tokenizer.

    Args:
        path_or_id: Local directory path **or** a HuggingFace model ID.

    Returns:
        ``(model, tokenizer, config)``
    """
    model_dir = Path(path_or_id)
    if not model_dir.exists():
        # Not a local path — treat it as a Hub model ID and fetch it.
        model_dir = download_weights(path_or_id)

    load_fn = (
        _load_converted_weights
        if _has_converted_layout(model_dir)
        else _load_original_weights
    )
    model, config = load_fn(model_dir)

    tokenizer = _load_tokenizer(model_dir)
    logger.info("Voxtral MLX model loaded from %s", model_dir)
    return model, tokenizer, config
|
||||
534
whisperlivekit/voxtral_mlx/model.py
Normal file
534
whisperlivekit/voxtral_mlx/model.py
Normal file
@@ -0,0 +1,534 @@
|
||||
"""
|
||||
Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model.
|
||||
|
||||
Architecture:
|
||||
audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits
|
||||
with DelayEmbedding providing time-conditioning to the decoder.
|
||||
|
||||
The model supports both batch inference (full audio) and incremental streaming
|
||||
(one chunk at a time with cached encoder/decoder state).
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV Cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SlidingKVCache:
    """Bounded key-value cache with rotating buffer for sliding-window attention.

    Uses in-place writes for single-token autoregressive steps and
    concatenation for multi-token prefills. Pre-allocates in blocks of
    ``alloc_step`` entries to reduce repeated allocation.

    Buffers are shaped ``[B, n_heads, T, head_dim]``; the time axis (2) is
    the one that rotates. ``_offset`` counts total tokens ever written,
    ``_write_idx`` is the next physical write slot in the buffer.
    """

    # Granularity of buffer growth during single-token writes.
    alloc_step = 256

    def __init__(self, capacity: int):
        # Maximum number of timesteps retained (the sliding window size).
        self.capacity = capacity
        self.keys = None
        self.values = None
        self._offset = 0
        self._write_idx = 0

    @property
    def offset(self) -> int:
        # Total tokens written so far; used as the RoPE position offset.
        return self._offset

    # -- helpers --

    def _reorder(self, buf):
        """Return *buf* in temporal order (unwrap the circular buffer)."""
        # Buffer exactly full and write pointer at the end: already ordered.
        if self._write_idx == buf.shape[2]:
            return buf
        # Wrapped: oldest entries start at _write_idx, newest end just before it.
        if self._write_idx < self._offset:
            return mx.concatenate(
                [buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
                axis=2,
            )
        # Not yet wrapped: only the first _write_idx slots hold real data.
        return buf[..., : self._write_idx, :]

    def _drop_oldest(self, buf, n_drop, tail=None):
        # Drop the n_drop oldest timesteps (no-op when n_drop <= 0) and
        # optionally append *tail* at the new end.
        parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
        if tail is not None:
            parts.append(tail)
        return mx.concatenate(parts, axis=2)

    # -- update strategies --

    def _append_concat(self, k, v):
        """Multi-token update via concatenation (used during prefill)."""
        if self.keys is None:
            self.keys, self.values = k, v
        else:
            # Linearize first so dropping "oldest" is a plain front slice.
            self.keys = self._reorder(self.keys)
            self.values = self._reorder(self.values)
            self._write_idx = self.keys.shape[2]
            # NOTE(review): the "+ 1" keeps one extra slot free relative to
            # capacity before appending k — presumably intentional headroom;
            # confirm against the attention window semantics.
            overflow = self._write_idx - self.capacity + 1
            self.keys = self._drop_oldest(self.keys, overflow, k)
            self.values = self._drop_oldest(self.values, overflow, v)
        self._offset += k.shape[2]
        self._write_idx = self.keys.shape[2]
        return self.keys, self.values

    def _write_inplace(self, k, v):
        """Single-token update via in-place write (autoregressive step)."""
        B, n_heads, S, dim_k = k.shape
        dim_v = v.shape[3]
        prev = self._offset

        # Grow (or create) the buffer when the write position has reached the
        # end and we are still below capacity.
        if self.keys is None or (
            prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
        ):
            n_new = min(self.alloc_step, self.capacity - prev)
            fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
            fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
                self.values = mx.concatenate([self.values, fresh_v], axis=2)
            else:
                self.keys, self.values = fresh_k, fresh_v
            self._write_idx = prev

        # Shrink if a previous prefill left the buffer above capacity.
        overflow = self.keys.shape[2] - self.capacity
        if overflow > 0:
            self.keys = self._drop_oldest(self.keys, overflow)
            self.values = self._drop_oldest(self.values, overflow)
            self._write_idx = self.capacity

        # Wrap the write pointer — this is where the buffer becomes circular.
        if self._write_idx == self.capacity:
            self._write_idx = 0

        self.keys[..., self._write_idx : self._write_idx + S, :] = k
        self.values[..., self._write_idx : self._write_idx + S, :] = v
        self._offset += S
        self._write_idx += S

        # Before the first wrap only a prefix of the buffer is valid.
        if self._offset < self.capacity:
            return (
                self.keys[..., : self._offset, :],
                self.values[..., : self._offset, :],
            )
        return self.keys, self.values

    # -- public API --

    def update_and_fetch(self, k, v):
        # Single token → cheap in-place write; multi-token prefill → concat.
        if k.shape[2] == 1:
            return self._write_inplace(k, v)
        return self._append_concat(k, v)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder components
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CausalConv(nn.Module):
    """1-D causal convolution (left-padded so no future leakage)."""

    def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
        super().__init__()
        self.stride = stride
        self.kernel = kernel
        # kernel - stride frames of left context make the conv causal while
        # keeping output length = input length / stride.
        self.left_pad = kernel - stride
        # Zero-initialized placeholders; real values come from load_weights.
        # MLX conv1d weight layout: [C_out, K, C_in].
        self.weight = mx.zeros((channels_out, kernel, channels_in))
        self.bias = mx.zeros((channels_out,))

    def __call__(self, x: mx.array) -> mx.array:
        # x: [B, T, C_in] → [B, T/stride, C_out]
        if self.left_pad > 0:
            x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
        return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
|
||||
|
||||
|
||||
class _EncoderSelfAttention(nn.Module):
    """Multi-head self-attention with RoPE for the encoder.

    Note the asymmetric biases (q/v/out have bias, k does not) — this
    mirrors the checkpoint's weight set.
    """

    def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
        self.rope_theta = rope_theta

    def __call__(self, x, mask, cache=None):
        B, L, _ = x.shape
        # [B, L, D] → [B, n_heads, L, head_dim]
        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)

        # RoPE positions continue from the cache so streaming chunks see
        # globally consistent positions.
        pos = cache.offset if cache is not None else 0
        q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
        k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)

        if cache is not None:
            k, v = cache.update_and_fetch(k, v)

        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
        return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
|
||||
|
||||
|
||||
class _EncoderFFN(nn.Module):
    """SwiGLU feed-forward block for encoder layers.

    Note: ``down`` carries a bias here, unlike the decoder FFN.
    """

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=True)

    def __call__(self, x):
        gated = nn.silu(self.gate(x)) * self.up(x)
        return self.down(gated)
|
||||
|
||||
|
||||
class _EncoderBlock(nn.Module):
    """One pre-norm transformer layer of the streaming encoder."""

    def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
        super().__init__()
        self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
        self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.ffn = _EncoderFFN(dim, hidden)

    def __call__(self, x, mask, cache=None):
        # Attention sublayer with residual connection.
        attn_out = self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
        h = x + attn_out
        # Feed-forward sublayer with residual connection.
        return h + self.ffn(self.pre_ffn_norm(h))
|
||||
|
||||
|
||||
class StreamingEncoder(nn.Module):
    """Causal Whisper-style encoder with two causal convolutions followed by
    a stack of transformer blocks. Supports both full-sequence and
    incremental (streaming) forward passes."""

    def __init__(
        self,
        mel_channels: int = 128,
        dim: int = 1280,
        n_layers: int = 32,
        n_heads: int = 32,
        head_dim: int = 64,
        hidden_dim: int = 5120,
        rope_theta: float = 1e6,
        sliding_window: int = 750,
    ):
        super().__init__()
        # conv1 keeps the frame rate; conv2 halves it (stride 2).
        self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
        self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
        self.blocks = [
            _EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
            for _ in range(n_layers)
        ]
        self.final_norm = nn.RMSNorm(dim, eps=1e-5)
        # NOTE(review): stored but not referenced in this class — presumably
        # the attention window is enforced elsewhere (e.g. cache capacity);
        # confirm.
        self.sliding_window = sliding_window

    # -- full-sequence --

    def _apply_convs(self, mel: mx.array) -> mx.array:
        """Run both causal convs over a full [mel_channels, T] spectrogram."""
        x = mel.T[None, :, :]  # [1, T, mel_channels]
        x = nn.gelu(self.conv1(x))
        x = nn.gelu(self.conv2(x))
        return x

    def forward(self, mel: mx.array) -> mx.array:
        """Full-sequence encode: [mel_channels, T] → [1, T/2, dim]."""
        x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
        for blk in self.blocks:
            x = blk(x, mask="causal")
        return self.final_norm(x)

    # -- incremental (streaming) --

    def forward_conv_incremental(self, x_in, tail1, tail2):
        """Process new mel frames through the two causal convs using cached tails.

        Args:
            x_in: [1, N, mel_channels]
            tail1: [1, pad1, mel_channels] or None (first call)
            tail2: [1, pad2, dim] or None (first call)

        Returns:
            (out, new_tail1, new_tail2)

        On the first call the missing left context is zero-padded; afterwards
        the last ``left_pad`` frames of each conv's input are carried over so
        chunked processing matches the full-sequence result.
        NOTE(review): assumes each chunk has at least ``left_pad`` frames
        (the tail slice would otherwise come up short) and, for conv2's
        stride 2, that chunk boundaries preserve stride phase — confirm with
        the caller's chunking.
        """
        # Conv1 (kernel=3, stride=1 → left_pad=2)
        if tail1 is not None:
            c1_in = mx.concatenate([tail1, x_in], axis=1)
        else:
            c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
        new_tail1 = x_in[:, -self.conv1.left_pad :, :]
        c1_out = nn.gelu(
            mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
        )

        # Conv2 (kernel=3, stride=2 → left_pad=1)
        if tail2 is not None:
            c2_in = mx.concatenate([tail2, c1_out], axis=1)
        else:
            c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
        new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
        c2_out = nn.gelu(
            mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
        )

        return c2_out, new_tail1, new_tail2

    def forward_transformer_incremental(self, x, cache_list):
        """Run transformer blocks with per-layer KV caches."""
        for i, blk in enumerate(self.blocks):
            x = blk(x, mask="causal", cache=cache_list[i])
        return self.final_norm(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder components
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _DecoderAttention(nn.Module):
    """Grouped-query attention for the text decoder.

    ``n_kv_heads`` < ``n_heads``: key/value heads are shared across query
    head groups (handled inside scaled_dot_product_attention).
    """

    def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope_theta = rope_theta

    def __call__(self, x, mask=None, cache=None):
        B, L, _ = x.shape
        # q: [B, n_heads, L, head_dim]; k/v: [B, n_kv_heads, L, head_dim]
        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

        # Continue RoPE positions from the cache across streaming steps.
        pos = cache.offset if cache is not None else 0
        q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
        k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)

        if cache is not None:
            k, v = cache.update_and_fetch(k, v)

        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
        return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
|
||||
|
||||
|
||||
class _DecoderFFN(nn.Module):
    """SwiGLU feed-forward block for decoder layers (all projections bias-free)."""

    def __init__(self, dim, hidden):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def __call__(self, x):
        gated = nn.silu(self.gate(x)) * self.up(x)
        return self.down(gated)
|
||||
|
||||
|
||||
class AdaptiveScaling(nn.Module):
    """Small MLP that produces a multiplicative scale from the delay embedding,
    used to condition the FFN on the streaming delay."""

    def __init__(self, dim, bottleneck):
        super().__init__()
        self.proj_in = nn.Linear(dim, bottleneck, bias=False)
        self.proj_out = nn.Linear(bottleneck, dim, bias=False)

    def __call__(self, cond):
        # dim → bottleneck → gelu → dim
        hidden = nn.gelu(self.proj_in(cond))
        return self.proj_out(hidden)
|
||||
|
||||
|
||||
class _DecoderBlock(nn.Module):
    """One pre-norm decoder layer: GQA attention + delay-conditioned SwiGLU FFN."""

    def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
        super().__init__()
        self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
        self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
        self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.ffn = _DecoderFFN(dim, hidden)

    def __call__(self, x, delay_cond, mask=None, cache=None):
        # Attention sublayer with residual.
        h = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
        # FFN input is modulated by a delay-dependent per-channel gain.
        gain = 1.0 + self.adaptive_scale(delay_cond)
        return h + self.ffn(self.pre_ffn_norm(h) * gain)
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
    """Mistral-style causal language model with adaptive time-conditioning.

    The output projection is tied to the token embedding (``as_linear`` in
    ``__call__``) — there is no separate LM head.
    """

    def __init__(
        self,
        dim: int = 3072,
        n_layers: int = 26,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        head_dim: int = 128,
        hidden_dim: int = 9216,
        vocab_size: int = 131072,
        rope_theta: float = 1e6,
        cond_dim: int = 32,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.blocks = [
            _DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
            for _ in range(n_layers)
        ]
        self.final_norm = nn.RMSNorm(dim, eps=1e-5)

    def embed(self, token_ids: mx.array) -> mx.array:
        """Look up token embeddings (exposed so callers can mix in audio embeds)."""
        return self.token_embedding(token_ids)

    def __call__(self, x, delay_cond, mask=None, cache=None):
        """Run the decoder over pre-computed embeddings.

        Args:
            x: [B, L, dim] input embeddings (text and/or audio).
            delay_cond: delay-conditioning vector fed to each block's
                adaptive scaling.
            mask: attention mask forwarded to every block.
            cache: optional per-layer KV caches (indexable by layer).

        Returns:
            [B, L, vocab_size] logits via the tied embedding matrix.
        """
        delay_cond = delay_cond.astype(x.dtype)
        for i, blk in enumerate(self.blocks):
            blk_cache = cache[i] if cache is not None else None
            x = blk(x, delay_cond, mask, blk_cache)
        x = self.final_norm(x)
        return self.token_embedding.as_linear(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter & embeddings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EncoderToDecoderAdapter(nn.Module):
    """Two-layer projection from encoder space to decoder space."""

    def __init__(self, enc_dim: int, dec_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
        self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)

    def __call__(self, x):
        # enc_dim → dec_dim → gelu → dec_dim
        projected = nn.gelu(self.linear1(x))
        return self.linear2(projected)
|
||||
|
||||
|
||||
class DelayEmbedding(nn.Module):
    """Sinusoidal embedding that encodes the streaming delay as a conditioning
    vector for the decoder's adaptive scaling."""

    def __init__(self, dim: int = 3072, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        half = dim // 2
        # Geometric frequency ladder: theta**(-i/half) for i in [0, half).
        freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
        # Leading underscore keeps this out of the trainable parameter tree —
        # it is a fixed buffer, not a weight.
        self._freqs = freqs

    def __call__(self, delay: mx.array) -> mx.array:
        """Map delay value(s) to a [N, dim] embedding (cos half ‖ sin half)."""
        t = delay.reshape(-1, 1).astype(mx.float32)
        angles = t * self._freqs
        return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VoxtralMLXModel(nn.Module):
    """Top-level Voxtral Realtime model wiring encoder, adapter, and decoder.

    Accepts both config schemas handled by the loader (params.json and
    config.json share the keys read here).
    """

    def __init__(self, config: dict):
        super().__init__()

        enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
        audio_cfg = enc_cfg["audio_encoding_args"]
        # Number of consecutive encoder frames stacked into one decoder-rate
        # audio embedding.
        ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]

        self.encoder = StreamingEncoder(
            mel_channels=audio_cfg["num_mel_bins"],
            dim=enc_cfg["dim"],
            n_layers=enc_cfg["n_layers"],
            n_heads=enc_cfg["n_heads"],
            head_dim=enc_cfg["head_dim"],
            hidden_dim=enc_cfg["hidden_dim"],
            rope_theta=enc_cfg["rope_theta"],
            sliding_window=enc_cfg["sliding_window"],
        )

        # Downsampling concatenates ds_factor encoder frames channel-wise,
        # so the adapter's input width is enc_dim * ds_factor.
        adapter_input_dim = enc_cfg["dim"] * ds_factor
        decoder_dim = config["dim"]
        cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)

        self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)

        self.decoder = TextDecoder(
            dim=decoder_dim,
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            n_kv_heads=config["n_kv_heads"],
            head_dim=config["head_dim"],
            hidden_dim=config["hidden_dim"],
            vocab_size=config["vocab_size"],
            rope_theta=config["rope_theta"],
            cond_dim=cond_bottleneck,
        )

        self.delay_embedding = DelayEmbedding(dim=decoder_dim)
        self.ds_factor = ds_factor

    # -- batch encode --

    def encode(self, mel: mx.array) -> mx.array:
        """Encode a full [mel_channels, T] spectrogram to decoder-space embeds."""
        T = mel.shape[1]
        # conv2 has stride 2, so drop the first frame to make T even.
        if T % 2 != 0:
            mel = mel[:, 1:]

        h = self.encoder.forward(mel)  # [1, T/2, enc_dim]
        h = h[0]

        # Drop leading frames so the count divides ds_factor evenly.
        n = h.shape[0]
        trim = n % self.ds_factor
        if trim:
            h = h[trim:]
            n = h.shape[0]

        # Stack ds_factor consecutive frames channel-wise, then project.
        h = h.reshape(n // self.ds_factor, -1)
        return self.adapter(h)

    # -- incremental encode --

    def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
        """Incrementally encode new mel frames.

        All state (conv tails, per-layer KV caches, downsample remainder) is
        threaded through arguments and return values — the model itself stays
        stateless across calls. Pass None for each on the first call.

        Returns:
            (audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
        """
        x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)

        x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)

        # Lazily create one KV cache per encoder layer. NOTE(review): the
        # 100k capacity is effectively "unbounded" for realtime sessions;
        # confirm intended interaction with encoder.sliding_window.
        if enc_cache is None:
            enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]

        x = self.encoder.forward_transformer_incremental(x, enc_cache)
        x = x[0]  # [N, enc_dim]

        # Prepend frames left over from the previous call's downsampling.
        if ds_remainder is not None:
            x = mx.concatenate([ds_remainder, x])

        n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
        if n_full == 0:
            # Not enough frames for one embedding yet — carry everything over.
            return None, conv_tail1, conv_tail2, enc_cache, x

        leftover = x[n_full:] if x.shape[0] > n_full else None
        x = x[:n_full].reshape(n_full // self.ds_factor, -1)
        return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover

    # -- decode --

    def decode(self, embeddings, delay_cond, mask=None, cache=None):
        """Forward pre-built embeddings through the text decoder (see TextDecoder)."""
        return self.decoder(embeddings, delay_cond, mask, cache)
|
||||
202
whisperlivekit/voxtral_mlx/spectrogram.py
Normal file
202
whisperlivekit/voxtral_mlx/spectrogram.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Mel spectrogram computation for Voxtral Realtime.
|
||||
|
||||
Provides both a full-audio function and an incremental streaming variant
|
||||
that maintains overlap state between calls. The DFT is computed via
|
||||
matrix multiplication in MLX — no external FFT dependency required.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
# Audio / mel constants matching the Voxtral Realtime model expectations.
SAMPLE_RATE = 16_000  # expected input sample rate (Hz)
WINDOW_SIZE = 400  # n_fft (25 ms at 16 kHz)
HOP = 160  # STFT hop length in samples (10 ms at 16 kHz)
MEL_BANDS = 128
MEL_MAX = 1.5  # global log-mel normalisation ceiling
# Each output audio token spans: hop * conv_stride(2) * downsample_factor(4)
SAMPLES_PER_TOKEN = HOP * 2 * 4  # = 1280 samples = 80 ms

# Padding tokens used by the model prompt structure.
LEFT_PAD_TOKENS = 32
RIGHT_PAD_TOKENS = 17
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slaney mel filterbank
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_slaney_filterbank(
    sr: int = SAMPLE_RATE,
    n_fft: int = WINDOW_SIZE,
    n_mels: int = MEL_BANDS,
    lo_hz: float = 0.0,
    hi_hz: float = 8000.0,
) -> np.ndarray:
    """Build the Slaney-normalised triangular mel filterbank.

    The hz<->mel mapping is piecewise: linear below 1 kHz and logarithmic
    above, following the Slaney (Auditory Toolbox) convention.

    Returns:
        ``[n_mels, n_fft//2 + 1]`` float32 matrix.
    """

    def _hz2mel(f):
        break_hz = 1000.0
        break_mel = 15.0
        scale = 27.0 / np.log(6.4)
        out = 3.0 * f / 200.0
        if isinstance(f, np.ndarray):
            hi = f >= break_hz
            out[hi] = break_mel + np.log(f[hi] / break_hz) * scale
            return out
        if f >= break_hz:
            return break_mel + np.log(f / break_hz) * scale
        return out

    def _mel2hz(m):
        break_hz = 1000.0
        break_mel = 15.0
        inv_scale = np.log(6.4) / 27.0
        out = 200.0 * m / 3.0
        hi = m >= break_mel
        out[hi] = break_hz * np.exp(inv_scale * (m[hi] - break_mel))
        return out

    bin_hz = np.linspace(0, sr / 2, n_fft // 2 + 1)
    edges_hz = _mel2hz(np.linspace(_hz2mel(lo_hz), _hz2mel(hi_hz), n_mels + 2))
    step = np.diff(edges_hz)

    # Each filter is the intersection of a rising and a falling ramp,
    # clipped at zero outside its triangle.
    delta = edges_hz[None, :] - bin_hz[:, None]
    up = -delta[:, :-2] / step[:-1]
    down = delta[:, 2:] / step[1:]
    bank = np.maximum(0.0, np.minimum(up, down))

    # Slaney area normalisation: every triangle integrates to the same area.
    bank *= (2.0 / (edges_hz[2 : n_mels + 2] - edges_hz[:n_mels]))[None, :]
    return bank.T.astype(np.float32)
|
||||
|
||||
|
||||
_CACHED_FILTERS: mx.array | None = None


def _mel_filters() -> mx.array:
    """Return the mel filterbank as an MLX array, building it on first use."""
    global _CACHED_FILTERS
    if _CACHED_FILTERS is not None:
        return _CACHED_FILTERS
    _CACHED_FILTERS = mx.array(_build_slaney_filterbank())
    return _CACHED_FILTERS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DFT helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _hann_window() -> mx.array:
    """Periodic Hann window of length WINDOW_SIZE as a float32 MLX array.

    ``np.hanning(N + 1)`` is symmetric; dropping its last sample yields the
    periodic variant used for STFT framing.
    """
    taps = np.hanning(WINDOW_SIZE + 1)[:-1]
    return mx.array(taps.astype(np.float32))
|
||||
|
||||
|
||||
# Lazily-built (cos, sin) DFT basis pair; depends only on module constants.
_CACHED_DFT: tuple | None = None


def _dft_matrices():
    """Return the real / imaginary DFT basis matrices, building them once.

    Each matrix has shape ``[n_bins, WINDOW_SIZE]``. The bases depend only
    on module constants, yet this function is called from ``_stft_frames``
    for every streaming chunk — caching (mirroring ``_CACHED_FILTERS``)
    avoids recomputing two 201x400 matrices per call.

    Returns:
        ``(cos_basis, sin_basis)`` tuple of MLX arrays.
    """
    global _CACHED_DFT
    if _CACHED_DFT is None:
        n_bins = WINDOW_SIZE // 2 + 1
        k = mx.arange(n_bins, dtype=mx.float32)[:, None]
        n = mx.arange(WINDOW_SIZE, dtype=mx.float32)[None, :]
        phase = -2.0 * math.pi * (k @ n) / WINDOW_SIZE
        _CACHED_DFT = (mx.cos(phase), mx.sin(phase))
    return _CACHED_DFT
|
||||
|
||||
|
||||
def _stft_frames(audio: mx.array, window: mx.array) -> mx.array:
    """Frame *audio*, apply the Hann *window*, and return the power spectrum.

    Returns a ``[n_frames, n_bins]`` array; the DFT is a matrix multiply
    against the pre-computed cos/sin bases.
    """
    n_bins = WINDOW_SIZE // 2 + 1
    n_frames = 1 + (audio.shape[0] - WINDOW_SIZE) // HOP
    if n_frames <= 0:
        # Not enough samples for even one full window.
        return mx.zeros((0, n_bins))

    # Gather all frames at once: row i covers [i*HOP, i*HOP + WINDOW_SIZE).
    starts = (mx.arange(n_frames) * HOP)[:, None]
    frame_idx = starts + mx.arange(WINDOW_SIZE)[None, :]
    frames = audio[frame_idx] * window[None, :]

    cos_basis, sin_basis = _dft_matrices()
    re = frames @ cos_basis.T
    im = frames @ sin_basis.T
    return re ** 2 + im ** 2
|
||||
|
||||
|
||||
def _apply_mel_and_log(power: mx.array) -> mx.array:
    """Project a power spectrogram onto the mel bands, log, clamp, rescale."""
    banded = power @ _mel_filters().T
    # Floor at 1e-10 before the log so silent bins don't produce -inf.
    out = mx.log10(mx.maximum(banded, 1e-10))
    # Clamp the dynamic range relative to the global ceiling.
    out = mx.maximum(out, MEL_MAX - 8.0)
    return (out + 4.0) / 4.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_mel(audio: np.ndarray) -> mx.array:
    """Compute the log-mel spectrogram of a complete audio signal.

    Args:
        audio: 1-D float32 numpy array at ``SAMPLE_RATE``.

    Returns:
        ``[MEL_BANDS, T]`` MLX array.
    """
    half = WINDOW_SIZE // 2
    # Centre-pad so the first/last frames are symmetric around the signal.
    padded = mx.pad(mx.array(audio), [(half, half)])
    spec = _stft_frames(padded, _hann_window())
    # Drop the final frame to match the reference STFT's frame count.
    return _apply_mel_and_log(spec[:-1]).T
|
||||
|
||||
|
||||
def compute_mel_streaming(
    chunk: np.ndarray,
    overlap: np.ndarray | None,
) -> tuple[mx.array, np.ndarray]:
    """Incrementally compute log-mel frames for a new chunk of audio.

    Args:
        chunk: New audio samples (float32 numpy).
        overlap: The trailing ``WINDOW_SIZE - HOP`` = 240 samples returned by
            the previous call, or None on the first call.

    Returns:
        ``(mel, new_overlap)`` where *mel* is ``[MEL_BANDS, N]`` and
        *new_overlap* is the 240-sample tail for the next call.
    """
    tail_len = WINDOW_SIZE - HOP  # 240

    # NOTE(review): the first call zero-pads with WINDOW_SIZE // 2 = 200
    # samples (mirroring compute_mel's centre padding) rather than
    # tail_len = 240 — presumed intentional; confirm against the batch path.
    if overlap is None:
        lead = np.zeros(WINDOW_SIZE // 2, dtype=np.float32)
    else:
        lead = overlap
    combined = np.concatenate([lead, chunk])

    new_overlap = combined[-tail_len:].copy()

    spec = _stft_frames(mx.array(combined), _hann_window())
    if spec.shape[0] == 0:
        return mx.zeros((MEL_BANDS, 0)), new_overlap
    return _apply_mel_and_log(spec).T, new_overlap
|
||||
|
||||
|
||||
def pad_audio(
    audio: np.ndarray,
    n_left: int = LEFT_PAD_TOKENS,
    n_right: int = RIGHT_PAD_TOKENS,
) -> np.ndarray:
    """Surround *audio* with silence for batch (non-streaming) inference.

    The left side gets ``n_left`` whole tokens of silence; the right side
    first rounds the signal up to a token boundary, then appends ``n_right``
    tokens of silence.
    """
    # -len % SPT == (SPT - len % SPT) % SPT: samples needed to reach a boundary.
    align = -len(audio) % SAMPLES_PER_TOKEN
    return np.pad(
        audio,
        (n_left * SAMPLES_PER_TOKEN, align + n_right * SAMPLES_PER_TOKEN),
    )
|
||||
521
whisperlivekit/voxtral_mlx_asr.py
Normal file
521
whisperlivekit/voxtral_mlx_asr.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""
|
||||
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.
|
||||
|
||||
Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
|
||||
(streaming processor) that plug into WhisperLiveKit's audio processing
|
||||
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.
|
||||
|
||||
Unlike the HuggingFace backend, this runs the full inference loop in-process
|
||||
(no background thread / queue) — MLX operations on Apple Silicon are fast
|
||||
enough to run synchronously inside ``asyncio.to_thread(process_iter)``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
from whisperlivekit.voxtral_mlx.loader import load_voxtral_model, DEFAULT_MODEL_ID
|
||||
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
|
||||
from whisperlivekit.voxtral_mlx.spectrogram import (
|
||||
SAMPLES_PER_TOKEN,
|
||||
LEFT_PAD_TOKENS,
|
||||
RIGHT_PAD_TOKENS,
|
||||
compute_mel_streaming,
|
||||
)
|
||||
|
||||
# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)

# Decoder sliding-window size (matches the model's training configuration).
_DECODER_WINDOW = 8192
|
||||
|
||||
|
||||
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
    """Build the prompt token sequence and return ``(token_ids, n_delay)``.

    The prompt is BOS followed by one [STREAMING_PAD] per left-pad token
    plus one per delay token.
    """
    pad_id = tokenizer.get_special_token("[STREAMING_PAD]")
    ids = [tokenizer.bos_id]
    ids.extend(pad_id for _ in range(n_left_pad + n_delay))
    return ids, n_delay
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model holder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VoxtralMLXASR:
    """Lightweight model holder — loads the MLX Voxtral model once and keeps
    it alive for the lifetime of the server."""

    sep = " "
    SAMPLING_RATE = 16_000

    def __init__(self, logfile=sys.stderr, **kwargs):
        """Resolve the model location from kwargs and load the MLX model.

        Recognised kwargs: ``lan`` (language code or "auto"), ``model_dir`` /
        ``model_path`` (explicit location), ``model_size`` (used as a path or
        repo id when it contains "/" or starts with ".").
        """
        self.logfile = logfile
        self.transcribe_kargs = {}

        lan = kwargs.get("lan", "auto")
        self.original_language = lan if lan != "auto" else None

        model_path = self._resolve_model_path(kwargs)

        t0 = time.time()
        logger.info("Loading Voxtral MLX model '%s' ...", model_path)
        self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
        logger.info("Voxtral MLX model loaded in %.2fs", time.time() - t0)

        self.backend_choice = "voxtral-mlx"

    @staticmethod
    def _resolve_model_path(kwargs) -> str:
        """Pick an explicit dir/path, a path-like model_size, or the default."""
        explicit = kwargs.get("model_dir") or kwargs.get("model_path")
        if explicit:
            return explicit
        model_size = kwargs.get("model_size", "")
        if model_size and ("/" in model_size or model_size.startswith(".")):
            return model_size
        return DEFAULT_MODEL_ID

    def transcribe(self, audio):
        # All real work happens in the online processor.
        pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Online processor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VoxtralMLXOnlineProcessor:
    """Streaming processor that incrementally encodes audio and decodes text
    using the MLX Voxtral model.

    Lifecycle (called by ``AudioProcessor.transcription_processor``):

        insert_audio_chunk(pcm, time) → process_iter() → get_buffer()
        ... repeat ...
        start_silence() / end_silence()
        finish()
    """

    SAMPLING_RATE = 16_000

    def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
        """Bind the shared model/tokenizer and pre-compute prompt constants."""
        self.asr = asr
        self.logfile = logfile
        self.end = 0.0  # stream-time (seconds) of the last ingested audio
        self.buffer: list = []
        self.audio_buffer = np.array([], dtype=np.float32)

        self._model = asr.model
        self._tokenizer = asr.tokenizer

        # Pre-compute prompt tokens and delay conditioning (constant across utterances).
        self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
        self._prefix_len = len(self._prompt_ids)

        self._delay_cond = self._model.delay_embedding(
            mx.array([self._n_delay], dtype=mx.float32)
        )
        mx.eval(self._delay_cond)

        self._prompt_embeds = self._model.decoder.embed(
            mx.array([self._prompt_ids])
        )[0]  # [prefix_len, dim]
        mx.eval(self._prompt_embeds)

        self._eos_id = self._tokenizer.eos_id
        self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
        # The streaming model has an inherent delay: text for audio at position P
        # is generated at decoder position P + n_delay. Compensate timestamps.
        self._delay_secs = self._n_delay * self._secs_per_token

        self._reset_state()

    # -- state management --

    def _reset_state(self):
        """Reset all incremental state for a fresh utterance."""
        # Audio accumulation
        self._pending = np.zeros(0, dtype=np.float32)
        # Mel overlap
        self._mel_overlap: np.ndarray | None = None
        # Encoder incremental state
        self._conv_tail1 = None
        self._conv_tail2 = None
        self._enc_cache = None
        self._ds_remainder = None
        # Audio embeddings not yet decoded
        self._audio_embeds: mx.array | None = None
        # Decoder state
        self._dec_cache: list[SlidingKVCache] | None = None
        self._last_token: mx.array | None = None
        # Bookkeeping
        self._samples_encoded = 0
        self._positions_decoded = 0
        self._prefilled = False
        self._first_chunk = True
        # Text state
        self._full_text = ""
        self._n_text_tokens = 0
        self._n_committed_words = 0
        self._time_offset = 0.0
        # Per-word audio position tracking: decoder position (relative to prefix)
        # where each word in _full_text started and ended
        self._word_audio_starts: list[int] = []  # audio pos where word i started
        self._word_audio_ends: list[int] = []  # audio pos where word i last produced a token
        self._current_word_pos: Optional[int] = None  # audio pos of current (incomplete) word's first token

    # -- audio ingestion --

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Append new PCM samples and record the stream end time."""
        self.end = audio_stream_end_time
        self._pending = np.append(self._pending, audio)
        self.audio_buffer = self._pending

    # -- core processing --

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Run one encode/decode step; never raises (errors are logged)."""
        try:
            return self._step(is_last)
        except Exception as e:
            logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
            return [], self.end

    def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
        """Encode pending audio, decode what is safe, return committed words.

        NOTE(review): ``is_last`` is currently unused here; the stream tail
        is handled by ``finish()``.
        """
        # 1. Encode any new audio
        self._encode_pending()

        if self._audio_embeds is None:
            return [], self.end

        # 2. Compute how many positions we can safely decode
        total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
        n_available = self._audio_embeds.shape[0]
        n_decodable = min(n_available, total_safe - self._positions_decoded)

        if n_decodable <= 0:
            return [], self.end

        # 3. Prefill if needed
        if not self._prefilled:
            if self._positions_decoded + n_available < self._prefix_len:
                return [], self.end
            self._do_prefill()
            # Re-check after consuming prefix embeddings
            n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
            n_decodable = min(n_available, total_safe - self._positions_decoded)

        if n_decodable <= 0 or self._audio_embeds is None:
            return [], self.end

        # 4. Decode available positions
        hit_eos = self._decode_positions(n_decodable)

        if hit_eos:
            # Flush words, reset for next utterance
            words = self._flush_all_words()
            logger.debug(
                "[voxtral-mlx] EOS hit during stream: flushed %d words, "
                "samples_encoded=%d (%.2fs), text='%s'",
                len(words), self._samples_encoded,
                self._samples_encoded / self.SAMPLING_RATE,
                self._full_text[-60:] if self._full_text else "",
            )
            # Preserve the accumulated silence offset across the reset so
            # timestamps for the next utterance remain in stream time.
            saved_offset = self._time_offset
            self._reset_state()
            self._time_offset = saved_offset
            return words, self.end

        # 5. Extract committed words (all but the last, which may still grow)
        return self._extract_committed_words(), self.end

    def _encode_pending(self):
        """Feed pending audio through the incremental encoder."""
        available = len(self._pending)
        if available < SAMPLES_PER_TOKEN:
            return

        if self._first_chunk:
            # First chunk: prepend silence for left-padding
            n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
            left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
            chunk = np.concatenate([left_pad, self._pending[:n_take]])
            self._pending = self._pending[n_take:]
            self._samples_encoded += n_take
            self._first_chunk = False
        else:
            n_take = (available // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
            chunk = self._pending[:n_take]
            self._pending = self._pending[n_take:]
            self._samples_encoded += n_take

        mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)

        embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
            self._model.encode_incremental(
                mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
            )
        )

        if embeds is not None:
            mx.eval(embeds)
            if self._audio_embeds is not None:
                self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
            else:
                self._audio_embeds = embeds

        self.audio_buffer = self._pending

    def _do_prefill(self):
        """Run the decoder prefill pass over the prompt + first audio embeddings."""
        n_dec_layers = len(self._model.decoder.blocks)
        self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]

        # Prompt token embeddings and audio embeddings are summed per position.
        prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
        prefix_embeds = prefix_embeds[None, :, :]  # [1, prefix_len, dim]

        logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
        mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])

        self._last_token = self._sample(logits)
        mx.async_eval(self._last_token)

        # Remove consumed prefix embeddings
        self._audio_embeds = self._audio_embeds[self._prefix_len :]
        if self._audio_embeds.shape[0] == 0:
            self._audio_embeds = None
        self._positions_decoded = self._prefix_len
        self._prefilled = True

    def _decode_positions(self, n: int) -> bool:
        """Autoregressively decode *n* positions. Returns True on EOS."""
        base_pos = self._positions_decoded  # absolute position before this batch
        for i in range(n):
            # Each step sums the current audio embedding with the embedding
            # of the previously sampled token.
            tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
            combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
            logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
            next_tok = self._sample(logits)
            mx.async_eval(next_tok)

            token_id = self._last_token.item()
            if token_id == self._eos_id:
                # Close the current word if one is being built
                if self._current_word_pos is not None:
                    self._word_audio_ends.append(base_pos + i - self._prefix_len)
                    self._current_word_pos = None
                self._trim_embeds(i)
                self._positions_decoded += i
                return True

            text = self._tokenizer.decode(
                [token_id], special_token_policy=SpecialTokenPolicy.IGNORE
            )

            if text:
                audio_pos = base_pos + i - self._prefix_len

                # Detect word boundary: new word starts with space or is the very first text
                if text.lstrip() != text or not self._full_text:
                    # Close previous word if exists
                    if self._current_word_pos is not None:
                        self._word_audio_ends.append(audio_pos)
                    # Start new word
                    self._word_audio_starts.append(audio_pos)
                    self._current_word_pos = audio_pos
                elif self._current_word_pos is None:
                    # First token of first word (no leading space)
                    self._word_audio_starts.append(audio_pos)
                    self._current_word_pos = audio_pos

                self._full_text += text
                self._n_text_tokens += 1

            # Periodically release MLX scratch memory during long decodes.
            if i > 0 and i % 256 == 0:
                mx.clear_cache()

            self._last_token = next_tok

        self._positions_decoded += n
        self._trim_embeds(n)
        return False

    def _trim_embeds(self, n_consumed: int):
        """Drop the first *n_consumed* audio embeddings (None when exhausted)."""
        if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
            self._audio_embeds = self._audio_embeds[n_consumed:]
        else:
            self._audio_embeds = None

    def _sample(self, logits: mx.array) -> mx.array:
        """Greedy sampling: argmax over the last position's logits."""
        return mx.argmax(logits[0, -1:], axis=-1).squeeze()

    # -- word extraction --

    def _audio_pos_to_time(self, pos: int) -> float:
        """Convert an audio position (relative to prefix end) to seconds."""
        return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)

    def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
        """Compute (start, end) time for a word using tracked word positions.

        NOTE(review): ``n_words`` is currently unused.
        """
        starts = self._word_audio_starts
        ends = self._word_audio_ends

        if not starts:
            return self._time_offset, self._time_offset

        # Get start position for this word
        if word_idx < len(starts):
            t0 = self._audio_pos_to_time(starts[word_idx])
        else:
            # Fallback: estimate from last known position
            last_pos = ends[-1] if ends else starts[-1]
            t0 = self._audio_pos_to_time(last_pos + 1)

        # Get end position: use the start of the next word, or the end of this word
        if word_idx + 1 < len(starts):
            t1 = self._audio_pos_to_time(starts[word_idx + 1])
        elif word_idx < len(ends):
            t1 = self._audio_pos_to_time(ends[word_idx] + 1)
        else:
            # Last word, still being built: use last known position + 1 token
            last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
            t1 = self._audio_pos_to_time(last_pos + 1)

        return t0, t1

    def _extract_committed_words(self) -> List[ASRToken]:
        """Return complete words (all except the last which may still grow)."""
        if not self._full_text:
            return []
        words = self._full_text.split()
        tokens: List[ASRToken] = []
        n_total = max(len(words), 1)

        # Commit every word but the trailing one, which may still be extended.
        while len(words) > self._n_committed_words + 1:
            w = words[self._n_committed_words]
            idx = self._n_committed_words
            t0, t1 = self._word_time_range(idx, n_total)
            # Leading space preserves joinability of the emitted token stream.
            label = w if idx == 0 else " " + w
            tokens.append(ASRToken(start=t0, end=t1, text=label))
            self._n_committed_words += 1

        return tokens

    def _flush_all_words(self) -> List[ASRToken]:
        """Flush every word including the last partial one."""
        if not self._full_text:
            return []
        words = self._full_text.split()
        tokens: List[ASRToken] = []
        n_total = max(len(words), 1)

        while self._n_committed_words < len(words):
            w = words[self._n_committed_words]
            idx = self._n_committed_words
            t0, t1 = self._word_time_range(idx, n_total)
            label = w if idx == 0 else " " + w
            tokens.append(ASRToken(start=t0, end=t1, text=label))
            self._n_committed_words += 1

        return tokens

    # -- interface methods --

    def get_buffer(self) -> Transcript:
        """Return the not-yet-committed tail of the transcript (may be empty)."""
        if not self._full_text:
            return Transcript(start=None, end=None, text="")
        words = self._full_text.split()
        remaining = words[self._n_committed_words :]
        if remaining:
            return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
        return Transcript(start=None, end=None, text="")

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush all words at the onset of a silence span."""
        words = self._flush_all_words()
        logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
        return words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Advance stream time past a skipped silence span.

        NOTE(review): ``offset`` is currently unused.
        """
        self._time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        # NOTE(review): the words flushed here are discarded — confirm the
        # caller collects committed words elsewhere on speaker change.
        self.start_silence()

    def warmup(self, audio, init_prompt=""):
        """No-op: the MLX backend needs no explicit warmup pass."""
        pass

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Flush the stream: pad, encode, decode the tail, and emit all words."""
        logger.debug(
            "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
            "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
            len(self._pending),
            self._audio_embeds.shape if self._audio_embeds is not None else None,
            self._samples_encoded,
            self._positions_decoded,
            self._prefilled,
            self._full_text[-80:] if self._full_text else "",
        )

        # Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
        remainder = len(self._pending) % SAMPLES_PER_TOKEN
        if remainder > 0:
            align_pad = SAMPLES_PER_TOKEN - remainder
        else:
            align_pad = 0

        # Add alignment + right-padding silence
        total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
        if total_pad > 0:
            self._pending = np.append(
                self._pending, np.zeros(total_pad, dtype=np.float32)
            )

        # Encode remaining audio (including right-padding)
        self._encode_pending()

        logger.debug(
            "[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
            self._audio_embeds.shape if self._audio_embeds is not None else None,
            len(self._pending),
        )

        hit_eos = False

        # Decode everything that's left from right-padding
        if self._audio_embeds is not None and self._prefilled:
            hit_eos = self._decode_positions(self._audio_embeds.shape[0])
            logger.debug(
                "[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
                hit_eos, self._full_text[-80:] if self._full_text else "",
            )

        # Flush last token if it wasn't EOS
        if self._last_token is not None:
            tid = self._last_token.item()
            if tid != self._eos_id:
                text = self._tokenizer.decode(
                    [tid], special_token_policy=SpecialTokenPolicy.IGNORE
                )
                if text:
                    last_pos = self._positions_decoded - self._prefix_len
                    # Check if this starts a new word
                    if text.lstrip() != text or not self._full_text:
                        if self._current_word_pos is not None:
                            self._word_audio_ends.append(last_pos)
                        self._word_audio_starts.append(last_pos)
                        self._current_word_pos = last_pos
                    elif self._current_word_pos is None:
                        self._word_audio_starts.append(last_pos)
                        self._current_word_pos = last_pos
                    self._full_text += text
                    self._n_text_tokens += 1

        # Close the last word if still open
        if self._current_word_pos is not None:
            last_pos = self._positions_decoded - self._prefix_len
            self._word_audio_ends.append(last_pos)
            self._current_word_pos = None

        words = self._flush_all_words()
        logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
        return words, self.end
|
||||
Reference in New Issue
Block a user