mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
Pure-MLX implementation of Voxtral Mini 4B Realtime for low-latency speech transcription on Apple Silicon. Avoids the transformers/torch overhead and runs at 0.18-0.32x real-time factor. - voxtral_mlx/model.py: MLX model with spectrogram, encoder, decoder - voxtral_mlx/loader.py: model loading with 6-bit quantized weights - voxtral_mlx/spectrogram.py: mel spectrogram computation in MLX - voxtral_mlx_asr.py: VoxtralASR adapter for the AudioProcessor pipeline
283 lines
11 KiB
Python
283 lines
11 KiB
Python
"""
|
|
Model weight loading for the MLX Voxtral Realtime backend.
|
|
|
|
Supports two on-disk formats:
|
|
1. **Converted** (``config.json`` + ``model.safetensors``): ready-to-load,
|
|
with optional quantisation metadata.
|
|
2. **Original Mistral** (``params.json`` + ``consolidated.safetensors``):
|
|
requires weight renaming and conv-weight transposition.
|
|
|
|
The public entry point is :func:`load_voxtral_model` which returns the
|
|
model, tokenizer, and raw config dict.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from .model import VoxtralMLXModel
|
|
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Model identifier used as the default argument of the download/load helpers.
DEFAULT_MODEL_ID = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Downloading
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# File-name patterns passed as ``allow_patterns`` when fetching a snapshot:
# weight files for both on-disk layouts, their configs, and the tokenizer file.
_ALLOWED_PATTERNS = [
    # Weights
    "consolidated.safetensors",
    "model*.safetensors",
    "model.safetensors.index.json",
    # Configuration
    "params.json",
    "config.json",
    # Tokenizer
    "tekken.json",
]
|
|
|
|
|
|
def download_weights(model_id: str = DEFAULT_MODEL_ID) -> Path:
    """Fetch the model files for *model_id* from the HuggingFace Hub.

    Only files matching ``_ALLOWED_PATTERNS`` are requested.

    Args:
        model_id: HuggingFace repository identifier.

    Returns:
        The local snapshot directory, as a :class:`~pathlib.Path`.
    """
    snapshot_dir = snapshot_download(model_id, allow_patterns=_ALLOWED_PATTERNS)
    return Path(snapshot_dir)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Weight name remapping (Mistral → our naming)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Ordered ``(pattern, replacement)`` pairs. ``_translate_weight_name`` anchors
# each pattern with ``^...$`` and returns the result of the first rule that
# matches, so the order of the entries matters.
_NAME_RULES: list[tuple[str, str]] = [
    # --- Encoder: convolutional front-end --------------------------------
    (r"whisper_encoder\.conv_layers\.0\.conv\.(.*)", r"encoder.conv1.\1"),
    (r"whisper_encoder\.conv_layers\.1\.conv\.(.*)", r"encoder.conv2.\1"),
    # --- Encoder: transformer layers (attention) -------------------------
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wq\.(.*)",
        r"encoder.blocks.\1.self_attn.q_proj.\2",
    ),
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wk\.(.*)",
        r"encoder.blocks.\1.self_attn.k_proj.\2",
    ),
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wv\.(.*)",
        r"encoder.blocks.\1.self_attn.v_proj.\2",
    ),
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(.*)",
        r"encoder.blocks.\1.self_attn.out_proj.\2",
    ),
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(.*)",
        r"encoder.blocks.\1.pre_attn_norm.\2",
    ),
    # --- Encoder: transformer layers (feed-forward) ----------------------
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(.*)",
        r"encoder.blocks.\1.ffn.gate.\2",
    ),
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(.*)",
        r"encoder.blocks.\1.ffn.down.\2",
    ),
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(.*)",
        r"encoder.blocks.\1.ffn.up.\2",
    ),
    (
        r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(.*)",
        r"encoder.blocks.\1.pre_ffn_norm.\2",
    ),
    # --- Encoder: final norm ---------------------------------------------
    (r"whisper_encoder\.transformer\.norm\.(.*)", r"encoder.final_norm.\1"),
    # --- Adapter (only the ``weight`` tensors are mapped) ------------------
    (r"audio_language_projection\.0\.weight", r"adapter.linear1.weight"),
    (r"audio_language_projection\.2\.weight", r"adapter.linear2.weight"),
    # --- Decoder: token embedding ----------------------------------------
    (r"tok_embeddings\.weight", r"decoder.token_embedding.weight"),
    # --- Decoder: transformer layers (attention) -------------------------
    (
        r"layers\.(\d+)\.attention\.wq\.weight",
        r"decoder.blocks.\1.self_attn.q_proj.weight",
    ),
    (
        r"layers\.(\d+)\.attention\.wk\.weight",
        r"decoder.blocks.\1.self_attn.k_proj.weight",
    ),
    (
        r"layers\.(\d+)\.attention\.wv\.weight",
        r"decoder.blocks.\1.self_attn.v_proj.weight",
    ),
    (
        r"layers\.(\d+)\.attention\.wo\.weight",
        r"decoder.blocks.\1.self_attn.out_proj.weight",
    ),
    (
        r"layers\.(\d+)\.attention_norm\.weight",
        r"decoder.blocks.\1.pre_attn_norm.weight",
    ),
    # --- Decoder: transformer layers (feed-forward) ----------------------
    (
        r"layers\.(\d+)\.feed_forward\.w1\.weight",
        r"decoder.blocks.\1.ffn.gate.weight",
    ),
    (
        r"layers\.(\d+)\.feed_forward\.w2\.weight",
        r"decoder.blocks.\1.ffn.down.weight",
    ),
    (
        r"layers\.(\d+)\.feed_forward\.w3\.weight",
        r"decoder.blocks.\1.ffn.up.weight",
    ),
    (
        r"layers\.(\d+)\.ffn_norm\.weight",
        r"decoder.blocks.\1.pre_ffn_norm.weight",
    ),
    # --- Decoder: adaptive scale projections -----------------------------
    (
        r"layers\.(\d+)\.ada_rms_norm_t_cond\.0\.weight",
        r"decoder.blocks.\1.adaptive_scale.proj_in.weight",
    ),
    (
        r"layers\.(\d+)\.ada_rms_norm_t_cond\.2\.weight",
        r"decoder.blocks.\1.adaptive_scale.proj_out.weight",
    ),
    # --- Decoder: final norm ---------------------------------------------
    (r"norm\.weight", r"decoder.final_norm.weight"),
]
|
|
|
|
# Matches one of two leading wrapper prefixes (including the trailing dot);
# ``_translate_weight_name`` removes it before applying the rename rules.
_PREFIX_STRIP = re.compile(r"^(mm_streams_embeddings\.embedding_module|mm_whisper_embeddings)\.")
|
|
|
|
|
|
def _translate_weight_name(name: str) -> str | None:
    """Map an original-format weight key onto this package's naming scheme.

    The wrapper prefix matched by ``_PREFIX_STRIP`` is removed first; the
    ``_NAME_RULES`` entries are then tried in order, each anchored to the
    whole key.

    Args:
        name: Weight key as stored in the checkpoint.

    Returns:
        The renamed key from the first matching rule, or ``None`` when no
        rule matches.
    """
    bare_name = _PREFIX_STRIP.sub("", name)
    # Lazy generator: rules after the first match are never evaluated.
    attempts = (
        re.subn(f"^{rule}$", target, bare_name) for rule, target in _NAME_RULES
    )
    return next((renamed for renamed, hits in attempts if hits), None)
|
|
|
|
|
|
def _is_conv_weight(name: str) -> bool:
    """Tell whether *name* refers to a ``conv1``/``conv2`` weight tensor.

    Args:
        name: Weight key to inspect.

    Returns:
        ``True`` when the key contains ``conv1.weight`` or ``conv2.weight``
        and does not contain ``bias`` anywhere; ``False`` otherwise.
    """
    if "bias" in name:
        return False
    return any(marker in name for marker in ("conv1.weight", "conv2.weight"))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Converted-format weight remapping (voxmlx names → our names)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Ordered ``(pattern, replacement)`` pairs for converted-format checkpoints.
# ``_remap_converted_name`` anchors each pattern with ``^...$`` and stops at
# the first rule that matches; keys matching no rule keep their name.
_CONVERTED_RULES: list[tuple[str, str]] = [
    # --- Adapter ---------------------------------------------------------
    (r"adapter\.w_in\.(.*)", r"adapter.linear1.\1"),
    (r"adapter\.w_out\.(.*)", r"adapter.linear2.\1"),
    # --- Encoder: transformer layers -------------------------------------
    (
        r"encoder\.layers\.(\d+)\.attention\.(.*)",
        r"encoder.blocks.\1.self_attn.\2",
    ),
    (
        r"encoder\.layers\.(\d+)\.attn_norm\.(.*)",
        r"encoder.blocks.\1.pre_attn_norm.\2",
    ),
    (
        r"encoder\.layers\.(\d+)\.mlp\.gate_proj\.(.*)",
        r"encoder.blocks.\1.ffn.gate.\2",
    ),
    (
        r"encoder\.layers\.(\d+)\.mlp\.down_proj\.(.*)",
        r"encoder.blocks.\1.ffn.down.\2",
    ),
    (
        r"encoder\.layers\.(\d+)\.mlp\.up_proj\.(.*)",
        r"encoder.blocks.\1.ffn.up.\2",
    ),
    (
        r"encoder\.layers\.(\d+)\.ffn_norm\.(.*)",
        r"encoder.blocks.\1.pre_ffn_norm.\2",
    ),
    (r"encoder\.norm\.(.*)", r"encoder.final_norm.\1"),
    # --- Decoder: token embedding ----------------------------------------
    (r"language_model\.embed_tokens\.(.*)", r"decoder.token_embedding.\1"),
    # --- Decoder: transformer layers -------------------------------------
    (
        r"language_model\.layers\.(\d+)\.attention\.(.*)",
        r"decoder.blocks.\1.self_attn.\2",
    ),
    (
        r"language_model\.layers\.(\d+)\.attn_norm\.(.*)",
        r"decoder.blocks.\1.pre_attn_norm.\2",
    ),
    (
        r"language_model\.layers\.(\d+)\.mlp\.gate_proj\.(.*)",
        r"decoder.blocks.\1.ffn.gate.\2",
    ),
    (
        r"language_model\.layers\.(\d+)\.mlp\.down_proj\.(.*)",
        r"decoder.blocks.\1.ffn.down.\2",
    ),
    (
        r"language_model\.layers\.(\d+)\.mlp\.up_proj\.(.*)",
        r"decoder.blocks.\1.ffn.up.\2",
    ),
    (
        r"language_model\.layers\.(\d+)\.ffn_norm\.(.*)",
        r"decoder.blocks.\1.pre_ffn_norm.\2",
    ),
    (
        r"language_model\.layers\.(\d+)\.ada_norm\.linear_in\.(.*)",
        r"decoder.blocks.\1.adaptive_scale.proj_in.\2",
    ),
    (
        r"language_model\.layers\.(\d+)\.ada_norm\.linear_out\.(.*)",
        r"decoder.blocks.\1.adaptive_scale.proj_out.\2",
    ),
    # --- Decoder: final norm ---------------------------------------------
    (r"language_model\.norm\.(.*)", r"decoder.final_norm.\1"),
]
|
|
|
|
# Substitutions applied by ``_remap_converted_name`` after the main rules:
# every ``.o_proj.`` path component becomes ``.out_proj.`` (this covers both
# encoder and decoder keys).
_POST_RENAME = [(r"\.o_proj\.", r".out_proj.")]
|
|
|
|
|
|
def _remap_converted_name(name: str) -> str:
    """Rename a converted-format weight key to this package's convention.

    The first ``_CONVERTED_RULES`` entry that matches the whole key is
    applied (keys matching no rule are kept as they are); afterwards every
    ``_POST_RENAME`` substitution is applied to the result.

    Args:
        name: Weight key as stored in the converted checkpoint.

    Returns:
        The renamed key.
    """
    # Lazy generator: rules after the first match are never evaluated.
    attempts = (
        re.subn(f"^{rule}$", target, name) for rule, target in _CONVERTED_RULES
    )
    renamed = next((text for text, hits in attempts if hits), name)

    for rule, target in _POST_RENAME:
        renamed = re.sub(rule, target, renamed)
    return renamed
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Loading strategies
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _has_converted_layout(path: Path) -> bool:
    """Tell whether *path* holds a converted-format checkpoint.

    Args:
        path: Directory to inspect.

    Returns:
        ``True`` when ``config.json`` exists in *path* and
        ``consolidated.safetensors`` does not; ``False`` otherwise.
    """
    # The presence of the original consolidated file rules out the
    # converted layout, whatever else is in the directory.
    if (path / "consolidated.safetensors").exists():
        return False
    return (path / "config.json").exists()
|
|
|
|
|
|
def _load_converted_weights(path: Path):
    """Load a converted-format checkpoint from *path*.

    Reads ``config.json``, builds the model from it, quantises the model
    when the config has a ``"quantization"`` entry, then loads
    ``model.safetensors`` (or every shard listed in
    ``model.safetensors.index.json``) after renaming the keys with
    ``_remap_converted_name``.

    Args:
        path: Directory holding the converted checkpoint files.

    Returns:
        ``(model, config)``: the populated model and the parsed config dict.
    """
    # JSON is UTF-8 by specification; do not rely on the locale default.
    with open(path / "config.json", encoding="utf-8") as f:
        config = json.load(f)

    model = VoxtralMLXModel(config)

    quant = config.get("quantization")
    if quant is not None:
        group_size = quant["group_size"]

        def _should_quantize(_module_path, module) -> bool:
            # Only modules that offer ``to_quantized`` and whose last weight
            # dimension is a whole multiple of the group size.
            return (
                hasattr(module, "to_quantized")
                and module.weight.shape[-1] % group_size == 0
            )

        # Quantisation happens before the weights are loaded, presumably so
        # the stored quantised tensors find matching parameters (TODO confirm).
        nn.quantize(
            model,
            group_size=group_size,
            bits=quant["bits"],
            class_predicate=_should_quantize,
        )

    index_file = path / "model.safetensors.index.json"
    if index_file.exists():
        with open(index_file, encoding="utf-8") as f:
            shard_map = json.load(f)
        # Several keys share a shard; load each distinct file once.
        weights = {}
        for shard_file in sorted(set(shard_map["weight_map"].values())):
            weights.update(mx.load(str(path / shard_file)))
    else:
        weights = mx.load(str(path / "model.safetensors"))

    remapped = {_remap_converted_name(k): v for k, v in weights.items()}
    model.load_weights(list(remapped.items()))
    # Evaluate the parameters before handing the model back.
    mx.eval(model.parameters())
    return model, config
|
|
|
|
|
|
def _load_original_weights(path: Path):
    """Load an original Mistral-format checkpoint from *path*.

    Reads ``params.json``, builds the model from it, then loads
    ``consolidated.safetensors``, renaming every key with
    ``_translate_weight_name`` and transposing convolution weights.

    Args:
        path: Directory holding ``params.json`` and
            ``consolidated.safetensors``.

    Returns:
        ``(model, config)``: the populated model and the parsed config dict.
    """
    # JSON is UTF-8 by specification; do not rely on the locale default.
    with open(path / "params.json", encoding="utf-8") as f:
        config = json.load(f)

    model = VoxtralMLXModel(config)

    raw = mx.load(str(path / "consolidated.safetensors"))
    mapped: dict[str, mx.array] = {}
    skipped: list[str] = []

    for name, tensor in raw.items():
        # ``output.weight`` is dropped on purpose and not reported as
        # unrecognised. NOTE(review): presumably the output projection is
        # tied to the token embedding; confirm against the model definition.
        if name == "output.weight":
            continue
        new_name = _translate_weight_name(name)
        if new_name is None:
            skipped.append(name)
            continue
        # Conv weights: PyTorch [C_out, C_in, K] -> MLX [C_out, K, C_in]
        if _is_conv_weight(new_name):
            tensor = mx.swapaxes(tensor, 1, 2)
        mapped[new_name] = tensor

    if skipped:
        logger.warning("Skipped %d unrecognised weight keys (first 5: %s)", len(skipped), skipped[:5])

    model.load_weights(list(mapped.items()))
    # Evaluate the parameters before handing the model back.
    mx.eval(model.parameters())
    return model, config
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tokenizer
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _load_tokenizer(model_dir: Path):
    """Create the tokenizer from ``tekken.json`` inside *model_dir*.

    Args:
        model_dir: Directory containing ``tekken.json``.

    Returns:
        The object returned by ``Tekkenizer.from_file``.
    """
    # Imported at call time rather than when this module is imported.
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    vocab_file = model_dir / "tekken.json"
    return Tekkenizer.from_file(str(vocab_file))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def load_voxtral_model(path_or_id: str = DEFAULT_MODEL_ID):
    """Load a Voxtral Realtime model and its tokenizer.

    If *path_or_id* names an existing filesystem path it is used directly;
    otherwise it is treated as a HuggingFace model ID and downloaded with
    ``download_weights``. The on-disk layout then decides which loader is
    used (converted vs. original Mistral format).

    Args:
        path_or_id: Local directory path **or** a HuggingFace model ID.
            A leading ``~`` in a local path is expanded to the user's home
            directory.

    Returns:
        ``(model, tokenizer, config)``
    """
    # Expand "~" first: an unexpanded "~/models/..." never passes exists()
    # and would be sent to the Hub as if it were a model ID.
    model_dir = Path(path_or_id).expanduser()
    if not model_dir.exists():
        model_dir = download_weights(path_or_id)

    if _has_converted_layout(model_dir):
        model, config = _load_converted_weights(model_dir)
    else:
        model, config = _load_original_weights(model_dir)

    tokenizer = _load_tokenizer(model_dir)
    logger.info("Voxtral MLX model loaded from %s", model_dir)
    return model, tokenizer, config
|