From 4d9332ce7da00a162f2201d071be63906f6a1ae1 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 5 Dec 2025 17:54:14 +0100 Subject: [PATCH] fixes #299 --- whisperlivekit/whisper/__init__.py | 53 ++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/whisperlivekit/whisper/__init__.py b/whisperlivekit/whisper/__init__.py index ce68de9..00cf761 100644 --- a/whisperlivekit/whisper/__init__.py +++ b/whisperlivekit/whisper/__init__.py @@ -108,7 +108,7 @@ def available_models() -> List[str]: def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]: """ attempt to infer ModelDimensions from a HF style config.json located - next to the given checkpoint, usefull for distilled models + next to the given checkpoint, useful for distilled models/MLX models. """ candidates = [] if os.path.isdir(path): @@ -122,6 +122,25 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]: with open(candidate, "r", encoding="utf-8") as f: config = json.load(f) + # native Whisper format + native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head", + "n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state", + "n_text_head", "n_text_layer"] + if all(k in config for k in native_keys): + return ModelDimensions( + n_mels=config["n_mels"], + n_audio_ctx=config["n_audio_ctx"], + n_audio_state=config["n_audio_state"], + n_audio_head=config["n_audio_head"], + n_audio_layer=config["n_audio_layer"], + n_vocab=config["n_vocab"], + n_text_ctx=config["n_text_ctx"], + n_text_state=config["n_text_state"], + n_text_head=config["n_text_head"], + n_text_layer=config["n_text_layer"], + ) + + # HuggingFace format try: return ModelDimensions( n_mels=config["num_mel_bins"], @@ -236,6 +255,24 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor return converted if converted else state_dict +def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Converts an mlx 
whisper checkpoint to a default openai whisper one + """ + if not any("mlp1" in k or "mlp2" in k for k in state_dict): + return state_dict + + converted = {} + for key, value in state_dict.items(): + if key == "alignment_heads": + continue + + new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.") + converted[new_key] = value + + return converted + + def _load_lora_state(lora_path: str): safe_path = os.path.join(lora_path, "adapter_model.safetensors") bin_path = os.path.join(lora_path, "adapter_model.bin") @@ -520,7 +557,12 @@ def load_model( state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint + + if alignment_heads is None and "alignment_heads" in state_dict: + alignment_heads = state_dict["alignment_heads"] + state_dict = _convert_hf_state_dict(state_dict) + state_dict = _convert_mlx_state_dict(state_dict) _apply_lora_adapter(state_dict, lora_path) if dims_cfg is not None: @@ -546,8 +588,13 @@ def load_model( model.load_state_dict(state_dict) if alignment_heads is not None: - model.set_alignment_heads(alignment_heads) - + if isinstance(alignment_heads, bytes): + model.set_alignment_heads(alignment_heads) + elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper + mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool) + for layer, head in alignment_heads.tolist(): + mask[layer, head] = True + model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) return model.to(device)