From 1bbbb7903caf010231bec7888b8bf2c4ee38927a Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sun, 16 Nov 2025 18:44:35 +0100 Subject: [PATCH] lora loader in shared whisper core --- whisperlivekit/whisper/__init__.py | 94 +++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 3 deletions(-) diff --git a/whisperlivekit/whisper/__init__.py b/whisperlivekit/whisper/__init__.py index 751020a..8ae68c9 100644 --- a/whisperlivekit/whisper/__init__.py +++ b/whisperlivekit/whisper/__init__.py @@ -4,11 +4,12 @@ import json import os import urllib import warnings -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union import torch from tqdm import tqdm from pathlib import Path +from torch import Tensor from .audio import load_audio, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult, decode, detect_language @@ -233,13 +234,97 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor return converted if converted else state_dict +def _load_lora_state(lora_path: str): + safe_path = os.path.join(lora_path, "adapter_model.safetensors") + bin_path = os.path.join(lora_path, "adapter_model.bin") + if os.path.isfile(safe_path): + try: + from safetensors.torch import load_file + except ImportError as exc: + raise ImportError( + "Loading LoRA adapters stored as .safetensors requires the `safetensors` package." + ) from exc + return load_file(safe_path) + if os.path.isfile(bin_path): + return torch.load(bin_path, map_location="cpu") + raise FileNotFoundError( + f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin." + ) + + +def _collapse_hf_module_name(module: str): + if module.startswith("base_model."): + module = module[len("base_model.") :] + if module.startswith("model.model."): + module = module[len("model.") :] + if not module.startswith("model."): + module = f"model.{module}" + return module + + +def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]): + if not lora_path: + return + + config_path = os.path.join(lora_path, "adapter_config.json") + if not os.path.isfile(config_path): + raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}") + with open(config_path, "r", encoding="utf-8") as handle: + config = json.load(handle) + if config.get("peft_type") != "LORA": + raise ValueError("Only LoRA adapters are supported.") + + r = config.get("r") + alpha = config.get("lora_alpha") or config.get("alpha") + if not r or not alpha: + raise ValueError("LoRA config must include `r` and `lora_alpha`.") + scaling = alpha / r + + adapter_state = _load_lora_state(lora_path) + lora_layers: Dict[str, Dict[str, Tensor]] = {} + for key, tensor in adapter_state.items(): + if key.endswith("lora_A.weight"): + module = key[: -len(".lora_A.weight")] + lora_layers.setdefault(module, {})["A"] = tensor + elif key.endswith("lora_B.weight"): + module = key[: -len(".lora_B.weight")] + lora_layers.setdefault(module, {})["B"] = tensor + + if not lora_layers: + raise ValueError(f"No LoRA tensors found in {lora_path}") + + for module, parts in lora_layers.items(): + if "A" not in parts or "B" not in parts: + raise ValueError(f"Incomplete LoRA tensors for module '{module}'") + + hf_module = _collapse_hf_module_name(module) + hf_weight_key = f"{hf_module}.weight" + + delta = parts["B"] @ parts["A"] + delta = delta * scaling + + converted = _convert_hf_state_dict({hf_weight_key: delta}) + if not converted: + raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.") + target_name, delta_tensor = next(iter(converted.items())) + if target_name not in state_dict: + raise KeyError( + f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter." + ) + + state_dict[target_name] = state_dict[target_name] + delta_tensor.to( + dtype=state_dict[target_name].dtype, device=state_dict[target_name].device + ) + + def load_model( name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False, - decoder_only=False, - custom_alignment_heads=None + decoder_only: bool = False, + custom_alignment_heads: Optional[str] = None, + lora_path: Optional[str] = None, ) -> Whisper: """ Load a Whisper ASR model @@ -255,6 +340,8 @@ def load_model( path to download the model files; by default, it uses "~/.cache/whisper" in_memory: bool whether to preload the model weights into host memory + lora_path: str + optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model) Returns ------- @@ -302,6 +389,7 @@ def load_model( else: state_dict = checkpoint state_dict = _convert_hf_state_dict(state_dict) + _apply_lora_adapter(state_dict, lora_path) if dims_cfg is not None: dims = ModelDimensions(**dims_cfg)