def load_mlx_model(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    """Load a full MLX Whisper model from a local directory or an HF repo.

    Parameters
    ----------
    path_or_hf_repo:
        Local directory containing ``config.json`` plus a weights file
        (``weights.safetensors`` or, as a fallback, ``weights.npz``), or a
        Hugging Face repo id to fetch with ``snapshot_download``.
    dtype:
        Parameter dtype used when instantiating the model.

    Returns
    -------
    whisper.Whisper
        The model with weights applied and parameters evaluated.

    Raises
    ------
    FileNotFoundError
        If no weights file is present in the resolved model directory.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        # Not a local path: treat the argument as a Hugging Face repo id.
        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(model_path / "config.json", "r") as f:
        config = json.load(f)
    # These keys are checkpoint metadata, not ModelDimensions fields.
    config.pop("model_type", None)
    quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    # Prefer safetensors; fall back to the legacy npz format.
    weights_file = model_path / "weights.safetensors"
    if not weights_file.exists():
        weights_file = model_path / "weights.npz"
    if not weights_file.exists():
        # Fail with a clear message instead of an opaque mx.load error.
        raise FileNotFoundError(
            f"No weights found in {model_path} "
            "(expected weights.safetensors or weights.npz)"
        )
    weights = mx.load(str(weights_file))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        # Quantize only the modules whose quantized scales are actually in
        # the checkpoint; everything else stays in full precision.
        def class_predicate(p, m):
            return (
                isinstance(m, (nn.Linear, nn.Embedding))
                and f"{p}.scales" in weights
            )

        nn.quantize(model, **quantization, class_predicate=class_predicate)

    weights = tree_unflatten(list(weights.items()))
    model.update(weights)
    mx.eval(model.parameters())
    return model