@@ -18,6 +18,8 @@ from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
from whisperlivekit.whisper.model import ModelDimensions, Whisper
from whisperlivekit.whisper.transcribe import transcribe
from whisperlivekit.whisper.version import __version__
from whisperlivekit.whisper.lora import (LoRAAdapter, LoRAAdapterManager,
                                         LoRAConfig, LoRALinear)

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
@@ -551,6 +553,94 @@ def load_model(
    return model.to(device)


def load_model_with_lora_manager(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
    decoder_only: bool = False,
    custom_alignment_heads: Optional[str] = None,
    adapters: Optional[Dict[str, str]] = None,
) -> tuple:
    """
    Load a Whisper model with a LoRA adapter manager for dynamic adapter swapping.

    This allows you to load multiple LoRA adapters and switch between them at runtime
    without keeping multiple full models in memory.

    Parameters
    ----------
    name : str
        Model name or path (same as load_model)
    device : Union[str, torch.device]
        Device to load model on
    download_root : str
        Download directory for model files
    in_memory : bool
        Whether to preload model weights into host memory
    decoder_only : bool
        If True, only load the decoder (no encoder)
    custom_alignment_heads : str
        Custom alignment heads configuration
    adapters : Dict[str, str]
        Optional dict mapping adapter names to paths/HuggingFace repo IDs.
        Example: {"french": "path/to/french-lora", "spanish": "user/spanish-whisper-lora"}

    Returns
    -------
    model : Whisper
        The base Whisper model (without any LoRA baked in)
    manager : LoRAAdapterManager
        The adapter manager for loading/switching adapters

    Example
    -------
    >>> model, manager = load_model_with_lora_manager(
    ...     "large-v3",
    ...     adapters={
    ...         "french": "path/to/french-lora",
    ...         "spanish": "path/to/spanish-lora"
    ...     }
    ... )
    >>>
    >>> # Switch to French adapter
    >>> manager.set_adapter("french")
    >>> result_fr = model.transcribe(audio_fr)
    >>>
    >>> # Switch to Spanish adapter
    >>> manager.set_adapter("spanish")
    >>> result_es = model.transcribe(audio_es)
    >>>
    >>> # Use base model without LoRA
    >>> manager.set_adapter(None)
    >>> result_base = model.transcribe(audio)
    >>>
    >>> # Check memory usage
    >>> print(manager.get_memory_usage())
    {'french': 12.5, 'spanish': 12.5}  # MB per adapter
    """
    # Load the base model WITHOUT any LoRA baked in
    model = load_model(
        name=name,
        device=device,
        download_root=download_root,
        in_memory=in_memory,
        decoder_only=decoder_only,
        custom_alignment_heads=custom_alignment_heads,
        lora_path=None,  # Important: no baked-in LoRA
    )

    # Create the adapter manager
    manager = LoRAAdapterManager(model)

    # Load any provided adapters
    if adapters:
        for adapter_name, adapter_path in adapters.items():
            manager.load_adapter(adapter_name, adapter_path)

    return model, manager


def convert_encoder_to_coreml(
    model_name = "base",
    output_path= "whisper_encoder.mlpackage",
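One way calling code might drive the new helper, shown here purely as an illustration and not part of the commit: pick an adapter per utterance from a caller-supplied language hint and fall back to the base model when no matching adapter is loaded. The adapter paths and the `detected_language` argument are hypothetical; only `load_model_with_lora_manager`, `list_adapters`, `set_adapter`, and `model.transcribe` come from the diff itself.

    # Illustrative sketch only -- hypothetical adapter paths and language hint.
    from whisperlivekit.whisper import load_model_with_lora_manager

    model, manager = load_model_with_lora_manager(
        "large-v3",
        adapters={
            "french": "path/to/french-lora",    # hypothetical local path
            "spanish": "path/to/spanish-lora",  # hypothetical local path
        },
    )

    def transcribe_with_adapter(audio, detected_language=None):
        # Use a language-specific adapter when one is loaded, otherwise the base model.
        adapter = detected_language if detected_language in manager.list_adapters() else None
        manager.set_adapter(adapter)
        return model.transcribe(audio)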
473  whisperlivekit/whisper/lora.py  Normal file

@@ -0,0 +1,473 @@
"""
|
||||
Dynamic LoRA adapter support for Whisper models.
|
||||
|
||||
This module enables loading a single base Whisper model and dynamically swapping
|
||||
between multiple LoRA adapters at runtime, saving GPU memory when working with
|
||||
multiple language-specific fine-tuned models.
|
||||
|
||||
Usage:
|
||||
from whisperlivekit.whisper import load_model
|
||||
from whisperlivekit.whisper.lora import LoRAAdapterManager
|
||||
|
||||
# Load base model without any LoRA baked in
|
||||
model = load_model("large-v3", device="cuda")
|
||||
|
||||
# Create adapter manager
|
||||
manager = LoRAAdapterManager(model)
|
||||
|
||||
# Load multiple adapters (small memory footprint each)
|
||||
manager.load_adapter("french", "path/to/french-lora")
|
||||
manager.load_adapter("spanish", "path/to/spanish-lora")
|
||||
|
||||
# Switch between adapters at runtime
|
||||
manager.set_adapter("french")
|
||||
result_fr = model.transcribe(audio_fr)
|
||||
|
||||
manager.set_adapter("spanish")
|
||||
result_es = model.transcribe(audio_es)
|
||||
|
||||
# Disable LoRA (use base model only)
|
||||
manager.set_adapter(None)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .model import Linear
|
||||
|
||||
|
||||
@dataclass
class LoRAConfig:
    """Configuration for a LoRA adapter."""
    r: int  # LoRA rank
    alpha: float  # LoRA alpha (scaling factor)
    target_modules: List[str] = field(default_factory=list)

    @property
    def scaling(self) -> float:
        return self.alpha / self.r

@dataclass
class LoRAAdapter:
    """Holds the LoRA A/B weight matrices for a single adapter."""
    name: str
    config: LoRAConfig
    # Maps target module name -> (A matrix, B matrix)
    weights: Dict[str, Tuple[Tensor, Tensor]] = field(default_factory=dict)
    device: torch.device = field(default_factory=lambda: torch.device("cpu"))
    dtype: torch.dtype = field(default=torch.float32)

    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
        """Move adapter weights to specified device/dtype."""
        self.device = device
        if dtype is not None:
            self.dtype = dtype
        self.weights = {
            name: (a.to(device=device, dtype=dtype or self.dtype),
                   b.to(device=device, dtype=dtype or self.dtype))
            for name, (a, b) in self.weights.items()
        }
        return self

    def memory_footprint_mb(self) -> float:
        """Return approximate memory usage in MB."""
        total_bytes = 0
        for a, b in self.weights.values():
            total_bytes += a.numel() * a.element_size()
            total_bytes += b.numel() * b.element_size()
        return total_bytes / (1024 * 1024)

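Back-of-the-envelope arithmetic for memory_footprint_mb, included as an illustration only. The rank, dtype, width, and set of targeted layers below are assumptions (r=16, float32, d_model=1280 as in the "large" checkpoints, q/v projections of 32 decoder blocks); real adapters can differ, but the result lands in the same order of magnitude as the ~12.5 MB per adapter quoted in the docstring example above.

    # Illustrative assumptions: rank r=16, float32, d_model=1280, LoRA on the
    # q/v projections of 32 decoder blocks. Actual adapters may target more or
    # fewer modules.
    d_model, r, bytes_per_param = 1280, 16, 4
    per_layer_mb = 2 * d_model * r * bytes_per_param / (1024 * 1024)  # A plus B
    total_mb = per_layer_mb * 32 * 2                                  # 32 blocks x (q, v)
    print(f"{per_layer_mb:.3f} MB per wrapped layer, ~{total_mb:.1f} MB per adapter")
    # -> roughly 0.156 MB per layer and ~10 MB per adapter.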
class LoRALinear(nn.Module):
    """
    A Linear layer wrapper that supports dynamic LoRA injection.

    The base weights remain unchanged. LoRA is applied additively during forward:
        output = base_linear(x) + (x @ A @ B) * scaling
    """

    def __init__(self, base_linear: Linear):
        super().__init__()
        self.base_linear = base_linear
        self.lora_A: Optional[Tensor] = None
        self.lora_B: Optional[Tensor] = None
        self.scaling: float = 1.0
        self._lora_enabled: bool = False

    def set_lora(self, A: Optional[Tensor], B: Optional[Tensor], scaling: float = 1.0):
        """Set the LoRA matrices for this layer."""
        self.lora_A = A
        self.lora_B = B
        self.scaling = scaling
        self._lora_enabled = A is not None and B is not None

    def clear_lora(self):
        """Remove LoRA from this layer."""
        self.lora_A = None
        self.lora_B = None
        self._lora_enabled = False

    def forward(self, x: Tensor) -> Tensor:
        # Base linear output
        out = self.base_linear(x)

        # Add LoRA contribution if enabled
        if self._lora_enabled and self.lora_A is not None and self.lora_B is not None:
            # x: (..., in_features)
            # A: (in_features, r)
            # B: (r, out_features)
            # lora_out: (..., out_features)
            lora_out = (x @ self.lora_A.to(x.dtype)) @ self.lora_B.to(x.dtype)
            out = out + lora_out * self.scaling

        return out

    # Delegate attribute access to base_linear for compatibility
    @property
    def weight(self):
        return self.base_linear.weight

    @property
    def bias(self):
        return self.base_linear.bias

    @property
    def in_features(self):
        return self.base_linear.in_features

    @property
    def out_features(self):
        return self.base_linear.out_features

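A standalone check of the additive formula quoted in LoRALinear's docstring, using a plain torch.nn.Linear rather than the project's Linear subclass (shapes, rank, and scaling below are arbitrary illustrative values). It also shows the equivalent "merged" view: folding scaling * (A @ B) transposed into the weight matrix gives the same output, which is why the base weights can stay untouched at runtime.

    # Minimal sketch of output = base(x) + (x @ A @ B) * scaling, torch only.
    import torch

    torch.manual_seed(0)
    in_f, out_f, r, scaling = 8, 6, 2, 2.0      # scaling = alpha / r, e.g. alpha=4, r=2
    base = torch.nn.Linear(in_f, out_f)
    A = torch.randn(in_f, r)                    # (in_features, r)
    B = torch.randn(r, out_f)                   # (r, out_features)
    x = torch.randn(3, in_f)

    # What LoRALinear.forward computes: base output plus the low-rank update.
    lora_out = base(x) + (x @ A @ B) * scaling

    # Equivalent merged view: fold scaling * (A @ B) into the weight matrix.
    merged = torch.nn.Linear(in_f, out_f)
    merged.load_state_dict(base.state_dict())
    with torch.no_grad():
        merged.weight += scaling * (A @ B).T    # nn.Linear stores weight as (out, in)

    assert torch.allclose(lora_out, merged(x), atol=1e-5)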
# Mapping from HuggingFace LoRA module names to Whisper module paths
_HF_TO_WHISPER_MODULE_MAP = {
    # Encoder attention
    "model.encoder.layers.{}.self_attn.q_proj": "encoder.blocks.{}.attn.query",
    "model.encoder.layers.{}.self_attn.k_proj": "encoder.blocks.{}.attn.key",
    "model.encoder.layers.{}.self_attn.v_proj": "encoder.blocks.{}.attn.value",
    "model.encoder.layers.{}.self_attn.out_proj": "encoder.blocks.{}.attn.out",
    # Encoder MLP
    "model.encoder.layers.{}.fc1": "encoder.blocks.{}.mlp.0",
    "model.encoder.layers.{}.fc2": "encoder.blocks.{}.mlp.2",

    # Decoder self-attention
    "model.decoder.layers.{}.self_attn.q_proj": "decoder.blocks.{}.attn.query",
    "model.decoder.layers.{}.self_attn.k_proj": "decoder.blocks.{}.attn.key",
    "model.decoder.layers.{}.self_attn.v_proj": "decoder.blocks.{}.attn.value",
    "model.decoder.layers.{}.self_attn.out_proj": "decoder.blocks.{}.attn.out",
    # Decoder cross-attention
    "model.decoder.layers.{}.encoder_attn.q_proj": "decoder.blocks.{}.cross_attn.query",
    "model.decoder.layers.{}.encoder_attn.k_proj": "decoder.blocks.{}.cross_attn.key",
    "model.decoder.layers.{}.encoder_attn.v_proj": "decoder.blocks.{}.cross_attn.value",
    "model.decoder.layers.{}.encoder_attn.out_proj": "decoder.blocks.{}.cross_attn.out",
    # Decoder MLP
    "model.decoder.layers.{}.fc1": "decoder.blocks.{}.mlp.0",
    "model.decoder.layers.{}.fc2": "decoder.blocks.{}.mlp.2",
}


def _normalize_hf_module_name(name: str) -> str:
    """Normalize HF-style LoRA module names."""
    if name.startswith("base_model."):
        name = name[len("base_model."):]
    if name.startswith("model.model."):
        name = name[len("model."):]
    if not name.startswith("model."):
        name = f"model.{name}"
    return name


def _map_hf_to_whisper_module(hf_name: str) -> Optional[str]:
    """Map a HuggingFace LoRA module name to Whisper module path."""
    hf_name = _normalize_hf_module_name(hf_name)

    # Try to match with layer index patterns
    import re

    # Match patterns like model.encoder.layers.5.self_attn.q_proj
    for pattern, target_pattern in _HF_TO_WHISPER_MODULE_MAP.items():
        # Create regex from pattern (replace {} with capture group)
        regex = pattern.replace(".", r"\.").replace("{}", r"(\d+)")
        match = re.fullmatch(regex, hf_name)
        if match:
            layer_idx = match.group(1)
            return target_pattern.format(layer_idx)

    return None

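A quick illustration of what the normalization and mapping helpers produce. The input keys follow the usual PEFT/Transformers naming convention and are examples only, not taken from a specific adapter.

    # Example inputs follow typical PEFT-style key names.
    print(_normalize_hf_module_name("base_model.model.model.decoder.layers.3.self_attn.q_proj"))
    # -> "model.decoder.layers.3.self_attn.q_proj"
    print(_map_hf_to_whisper_module("base_model.model.model.decoder.layers.3.self_attn.q_proj"))
    # -> "decoder.blocks.3.attn.query"
    print(_map_hf_to_whisper_module("model.encoder.layers.0.fc1"))
    # -> "encoder.blocks.0.mlp.0"
    print(_map_hf_to_whisper_module("some.unknown.module"))
    # -> None (load_adapter then falls back to using the key as a Whisper path)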
def _get_module_by_path(model: nn.Module, path: str) -> Optional[nn.Module]:
    """Get a submodule by dot-separated path."""
    parts = path.split(".")
    current = model
    for part in parts:
        if hasattr(current, part):
            current = getattr(current, part)
        elif hasattr(current, "__getitem__"):
            try:
                current = current[int(part)]
            except (ValueError, IndexError, KeyError):
                return None
        else:
            return None
    return current


def _set_module_by_path(model: nn.Module, path: str, module: nn.Module):
    """Set a submodule by dot-separated path."""
    parts = path.split(".")
    parent = model
    for part in parts[:-1]:
        if hasattr(parent, part):
            parent = getattr(parent, part)
        elif hasattr(parent, "__getitem__"):
            parent = parent[int(part)]
    setattr(parent, parts[-1], module)

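A small sketch of the path helpers in isolation. They walk dotted paths that mix attribute names with numeric container entries (e.g. the "0" in "encoder.blocks.0.mlp.0", which lives in a ModuleList/Sequential). The toy structure below only mimics the Whisper layout for illustration; it is not the real model.

    # Toy structure mimicking encoder.blocks.<i>.mlp.<j> from the Whisper model.
    import torch.nn as nn

    toy = nn.Module()
    toy.encoder = nn.Module()
    block = nn.Module()
    block.mlp = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 4))
    toy.encoder.blocks = nn.ModuleList([block])

    fc1 = _get_module_by_path(toy, "encoder.blocks.0.mlp.0")
    print(type(fc1).__name__)                                          # -> Linear
    _set_module_by_path(toy, "encoder.blocks.0.mlp.0", nn.Identity())  # swap in place
    print(type(_get_module_by_path(toy, "encoder.blocks.0.mlp.0")).__name__)  # -> Identity

This is the same swap-in-place mechanism LoRAAdapterManager uses to wrap target Linear layers with LoRALinear and later restore them.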
class LoRAAdapterManager:
    """
    Manages multiple LoRA adapters for a Whisper model.

    Enables loading multiple adapters and switching between them at runtime
    without reloading the full model.
    """

    def __init__(self, model: nn.Module):
        """
        Initialize the adapter manager.

        Args:
            model: A Whisper model instance
        """
        self.model = model
        self.adapters: Dict[str, LoRAAdapter] = {}
        self.current_adapter: Optional[str] = None
        self._lora_layers: Dict[str, LoRALinear] = {}
        self._original_layers: Dict[str, Linear] = {}
        self._initialized = False

    def _initialize_lora_layers(self, target_modules: List[str]):
        """
        Replace target Linear layers with LoRALinear wrappers.

        This is done lazily on first adapter load.
        """
        if self._initialized:
            return

        # Find and wrap all potential LoRA target modules
        for whisper_path in target_modules:
            module = _get_module_by_path(self.model, whisper_path)
            if module is None:
                continue
            if isinstance(module, Linear) and not isinstance(module, LoRALinear):
                # Wrap the Linear layer
                lora_linear = LoRALinear(module)
                _set_module_by_path(self.model, whisper_path, lora_linear)
                self._lora_layers[whisper_path] = lora_linear
                self._original_layers[whisper_path] = module

        self._initialized = True

    def _resolve_lora_path(self, lora_path: str) -> str:
        """Resolve LoRA path, downloading from HuggingFace Hub if needed."""
        if os.path.isdir(lora_path):
            return lora_path

        # Try HuggingFace Hub
        if "/" in lora_path:
            try:
                from huggingface_hub import snapshot_download
                return snapshot_download(
                    repo_id=lora_path,
                    allow_patterns=["adapter_config.json", "adapter_model.*"],
                )
            except Exception as e:
                raise FileNotFoundError(
                    f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
                )

        raise FileNotFoundError(f"LoRA path '{lora_path}' not found.")

    def _load_adapter_weights(self, lora_path: str) -> Dict[str, Tensor]:
        """Load adapter weights from safetensors or bin file."""
        safe_path = os.path.join(lora_path, "adapter_model.safetensors")
        bin_path = os.path.join(lora_path, "adapter_model.bin")

        if os.path.isfile(safe_path):
            from safetensors.torch import load_file
            return load_file(safe_path)
        elif os.path.isfile(bin_path):
            return torch.load(bin_path, map_location="cpu")
        else:
            raise FileNotFoundError(
                f"No adapter weights found in {lora_path}. "
                "Expected adapter_model.safetensors or adapter_model.bin."
            )

    def load_adapter(
        self,
        name: str,
        lora_path: str,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> LoRAAdapter:
        """
        Load a LoRA adapter from disk or HuggingFace Hub.

        Args:
            name: Unique name for this adapter (e.g., "french", "spanish")
            lora_path: Local path or HuggingFace repo ID
            device: Device to load weights to (default: model's device)
            dtype: Data type for weights (default: model's dtype)

        Returns:
            The loaded LoRAAdapter
        """
        if device is None:
            device = next(self.model.parameters()).device
        if dtype is None:
            dtype = next(self.model.parameters()).dtype

        # Resolve path
        lora_path = self._resolve_lora_path(lora_path)

        # Load config
        config_path = os.path.join(lora_path, "adapter_config.json")
        if not os.path.isfile(config_path):
            raise FileNotFoundError(f"Missing adapter_config.json in {lora_path}")

        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        if config_dict.get("peft_type") != "LORA":
            raise ValueError("Only LoRA adapters are supported.")

        config = LoRAConfig(
            r=config_dict["r"],
            alpha=config_dict.get("lora_alpha") or config_dict.get("alpha"),
            target_modules=config_dict.get("target_modules", []),
        )

        # Load weights
        adapter_state = self._load_adapter_weights(lora_path)

        # Parse LoRA A/B matrices and map to Whisper module paths
        lora_layers: Dict[str, Dict[str, Tensor]] = {}
        for key, tensor in adapter_state.items():
            if key.endswith("lora_A.weight"):
                module = key[:-len(".lora_A.weight")]
                lora_layers.setdefault(module, {})["A"] = tensor
            elif key.endswith("lora_B.weight"):
                module = key[:-len(".lora_B.weight")]
                lora_layers.setdefault(module, {})["B"] = tensor

        # Map to Whisper module paths and collect weights
        weights: Dict[str, Tuple[Tensor, Tensor]] = {}
        whisper_paths = set()

        for hf_module, parts in lora_layers.items():
            if "A" not in parts or "B" not in parts:
                continue

            whisper_path = _map_hf_to_whisper_module(hf_module)
            if whisper_path is None:
                # Try direct mapping (module might already be in Whisper format)
                whisper_path = hf_module

            # A: (r, in_features) -> transpose to (in_features, r)
            # B: (out_features, r) -> transpose to (r, out_features)
            A = parts["A"].T  # (in_features, r)
            B = parts["B"].T  # (r, out_features)

            weights[whisper_path] = (A, B)
            whisper_paths.add(whisper_path)

        # Create adapter
        adapter = LoRAAdapter(
            name=name,
            config=config,
            weights=weights,
            device=device,
            dtype=dtype,
        )
        adapter.to(device, dtype)

        # Initialize LoRA layers if not done yet
        self._initialize_lora_layers(list(whisper_paths))

        # Store adapter
        self.adapters[name] = adapter

        return adapter

    def set_adapter(self, name: Optional[str]):
        """
        Switch to a different adapter or disable LoRA.

        Args:
            name: Adapter name to activate, or None to disable all LoRA
        """
        if name is not None and name not in self.adapters:
            raise KeyError(f"Adapter '{name}' not loaded. Available: {list(self.adapters.keys())}")

        # Clear all LoRA from layers
        for lora_linear in self._lora_layers.values():
            lora_linear.clear_lora()

        self.current_adapter = name

        if name is None:
            return

        # Apply the selected adapter
        adapter = self.adapters[name]
        for module_path, (A, B) in adapter.weights.items():
            if module_path in self._lora_layers:
                self._lora_layers[module_path].set_lora(A, B, adapter.config.scaling)

    def unload_adapter(self, name: str):
        """
        Unload an adapter from memory.

        Args:
            name: Name of adapter to unload
        """
        if name not in self.adapters:
            return

        if self.current_adapter == name:
            self.set_adapter(None)

        del self.adapters[name]

    def list_adapters(self) -> List[str]:
        """Return list of loaded adapter names."""
        return list(self.adapters.keys())

    def get_memory_usage(self) -> Dict[str, float]:
        """Return memory usage in MB for each loaded adapter."""
        return {name: adapter.memory_footprint_mb() for name, adapter in self.adapters.items()}

    def restore_original_layers(self):
        """
        Restore the original Linear layers, removing LoRA wrappers.

        Call this if you want to go back to the original model structure.
        """
        for path, original in self._original_layers.items():
            _set_module_by_path(self.model, path, original)

        self._lora_layers.clear()
        self._original_layers.clear()
        self._initialized = False
        self.current_adapter = None
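Finally, a round-trip sketch of the on-disk format load_adapter expects: a PEFT-style adapter_config.json next to adapter_model.safetensors (or adapter_model.bin), with weight keys ending in ".lora_A.weight" / ".lora_B.weight". Everything below is illustrative: the rank, shapes, and module names are made up, and the tiny stand-in "model" is only there to satisfy the manager, so no layers actually get wrapped; with a real Whisper model the matching attention/MLP Linears would be replaced by LoRALinear wrappers.

    # Write a minimal PEFT-style adapter to a temp directory and load it back.
    import json, os, tempfile
    import torch
    from safetensors.torch import save_file
    from whisperlivekit.whisper.lora import LoRAAdapterManager

    tmp = tempfile.mkdtemp()
    with open(os.path.join(tmp, "adapter_config.json"), "w") as f:
        json.dump({"peft_type": "LORA", "r": 2, "lora_alpha": 4,
                   "target_modules": ["q_proj", "v_proj"]}, f)
    save_file({
        "base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(2, 8),
        "base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(8, 2),
    }, os.path.join(tmp, "adapter_model.safetensors"))

    stand_in = torch.nn.Linear(8, 8)       # any module with parameters will do here
    manager = LoRAAdapterManager(stand_in)
    adapter = manager.load_adapter("demo", tmp)
    print(adapter.config.scaling)          # -> 2.0  (alpha / r)
    print(list(adapter.weights))           # -> ['decoder.blocks.0.attn.query']
    print(manager.get_memory_usage())      # per-adapter footprint in MB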