# Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
# Synced 2026-03-07 22:33:36 +00:00 (71 lines, 2.2 KiB, Python).
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from mlx_whisper import whisper

# Hugging Face repos hosting MLX-converted Whisper checkpoints, keyed by
# model size. Every repo follows the "mlx-community/whisper-<size>-mlx"
# naming pattern except the large-v3-turbo conversion, which carries no
# "-mlx" suffix.
mlx_model_mapping = {
    size: "mlx-community/whisper-" + size + ("" if size == "large-v3-turbo" else "-mlx")
    for size in (
        "tiny.en", "tiny",
        "base.en", "base",
        "small.en", "small",
        "medium.en", "medium",
        "large-v1", "large-v2", "large-v3", "large-v3-turbo", "large",
    )
}

def load_mlx_encoder(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    """Load a Whisper model from MLX weights, loading only the encoder's weights.

    A full ``whisper.Whisper`` module (encoder + decoder) is instantiated, but
    only the encoder weights from disk are applied to it; the decoder weights
    are discarded without ever being set. For tiny.en the on-disk split is
    roughly 59 MB of decoder weights vs 15 MB of encoder weights, so skipping
    the decoder load saves most of the transfer.

    Args:
        path_or_hf_repo: Local directory containing the converted model, or a
            Hugging Face repo id (e.g. "mlx-community/whisper-tiny-mlx") to
            download via ``snapshot_download``.
        dtype: Parameter dtype passed to the ``whisper.Whisper`` constructor.

    Returns:
        The model with encoder weights loaded and evaluated.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        # Not a local directory: treat the argument as a Hugging Face repo id.
        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(model_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    # "model_type" and "quantization" are repo metadata, not ModelDimensions
    # fields; pop them before unpacking into the dataclass.
    config.pop("model_type", None)
    quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    # Prefer safetensors; older conversions ship weights.npz instead.
    weights_file = model_path / "weights.safetensors"
    if not weights_file.exists():
        weights_file = model_path / "weights.npz"
    weights = mx.load(str(weights_file))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        # Quantize only the layers that actually have quantized weights on
        # disk (identified by a matching "<path>.scales" entry).
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
        nn.quantize(model, **quantization, class_predicate=class_predicate)

    weights = tree_unflatten(list(weights.items()))

    # Keep only the encoder subtree; drop the (much larger) decoder weights
    # before materializing parameters to reduce peak memory.
    encoder_weights = {"encoder": weights["encoder"]}
    del weights

    model.update(encoder_weights)
    mx.eval(model.parameters())
    return model