This commit is contained in:
Quentin Fuxa
2025-12-05 17:54:14 +01:00
parent 62444ce746
commit 4d9332ce7d

View File

@@ -108,7 +108,7 @@ def available_models() -> List[str]:
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
"""
attempt to infer ModelDimensions from a HF style config.json located
next to the given checkpoint, usefull for distilled models
next to the given checkpoint, usefull for distilled models/MLX models.
"""
candidates = []
if os.path.isdir(path):
@@ -122,6 +122,25 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
with open(candidate, "r", encoding="utf-8") as f:
config = json.load(f)
# native Whisper format
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
"n_text_head", "n_text_layer"]
if all(k in config for k in native_keys):
return ModelDimensions(
n_mels=config["n_mels"],
n_audio_ctx=config["n_audio_ctx"],
n_audio_state=config["n_audio_state"],
n_audio_head=config["n_audio_head"],
n_audio_layer=config["n_audio_layer"],
n_vocab=config["n_vocab"],
n_text_ctx=config["n_text_ctx"],
n_text_state=config["n_text_state"],
n_text_head=config["n_text_head"],
n_text_layer=config["n_text_layer"],
)
# HuggingFace format
try:
return ModelDimensions(
n_mels=config["num_mel_bins"],
@@ -236,6 +255,24 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
return converted if converted else state_dict
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Converts an mlx whisper checkpoint to a default openai whisper one
"""
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
return state_dict
converted = {}
for key, value in state_dict.items():
if key == "alignment_heads":
continue
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
converted[new_key] = value
return converted
def _load_lora_state(lora_path: str):
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin")
@@ -520,7 +557,12 @@ def load_model(
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
if alignment_heads is None and "alignment_heads" in state_dict:
alignment_heads = state_dict["alignment_heads"]
state_dict = _convert_hf_state_dict(state_dict)
state_dict = _convert_mlx_state_dict(state_dict)
_apply_lora_adapter(state_dict, lora_path)
if dims_cfg is not None:
@@ -546,8 +588,13 @@ def load_model(
model.load_state_dict(state_dict)
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
if isinstance(alignment_heads, bytes):
model.set_alignment_heads(alignment_heads)
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
for layer, head in alignment_heads.tolist():
mask[layer, head] = True
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
return model.to(device)