From 82cd24bb75254b685614b399b0360c9999d35dc6 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sat, 29 Nov 2025 17:21:10 +0100 Subject: [PATCH] LoRA path v0 - functional --- README.md | 1 + whisperlivekit/core.py | 1 + whisperlivekit/local_agreement/backends.py | 7 ++-- .../local_agreement/whisper_online.py | 2 + whisperlivekit/parse_args.py | 7 ++++ whisperlivekit/simul_whisper/backend.py | 4 +- whisperlivekit/whisper/__init__.py | 40 +++++++++++++++++++ 7 files changed, 58 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index bd7e4a0..cb3ad94 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,7 @@ async def websocket_endpoint(websocket: WebSocket): | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` | | `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` | +| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). 
Only works with native Whisper backend (`--backend whisper`) | `None` | | Translation options | Description | Default | |-----------|-------------|---------| diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 88573d1..d510f63 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -59,6 +59,7 @@ class TranscriptionEngine: "model_cache_dir": None, "model_dir": None, "model_path": None, + "lora_path": None, "lan": "auto", "direct_english_translation": False, } diff --git a/whisperlivekit/local_agreement/backends.py b/whisperlivekit/local_agreement/backends.py index aadddc6..001e13e 100644 --- a/whisperlivekit/local_agreement/backends.py +++ b/whisperlivekit/local_agreement/backends.py @@ -16,9 +16,10 @@ class ASRBase: sep = " " # join transcribe words with this character (" " for whisper_timestamped, # "" for faster-whisper because it emits the spaces when needed) - def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr): + def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr): self.logfile = logfile self.transcribe_kargs = {} + self.lora_path = lora_path if lan == "auto": self.original_language = None else: @@ -58,12 +59,12 @@ class WhisperASR(ASRBase): f"No supported PyTorch checkpoint found under {resolved_path}" ) logger.debug(f"Loading Whisper model from custom path {resolved_path}") - return load_whisper_model(str(resolved_path)) + return load_whisper_model(str(resolved_path), lora_path=self.lora_path) if model_size is None: raise ValueError("Either model_size or model_dir must be set for WhisperASR") - return load_whisper_model(model_size, download_root=cache_dir) + return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path) def transcribe(self, audio, init_prompt=""): options = dict(self.transcribe_kargs) diff --git a/whisperlivekit/local_agreement/whisper_online.py 
b/whisperlivekit/local_agreement/whisper_online.py index 4256dec..d74ac54 100644 --- a/whisperlivekit/local_agreement/whisper_online.py +++ b/whisperlivekit/local_agreement/whisper_online.py @@ -77,6 +77,7 @@ def backend_factory( model_cache_dir, model_dir, model_path, + lora_path, direct_english_translation, buffer_trimming, buffer_trimming_sec, @@ -138,6 +139,7 @@ def backend_factory( lan=lan, cache_dir=model_cache_dir, model_dir=model_override, + lora_path=lora_path if backend_choice == "whisper" else None, ) e = time.time() logger.info(f"done. It took {round(e-t,2)} seconds.") diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 7f67a97..9b5da4d 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -106,6 +106,13 @@ def parse_args(): default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.", ) + parser.add_argument( + "--lora-path", + type=str, + default=None, + dest="lora_path", + help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). 
Only works with native Whisper backend.", + ) parser.add_argument( "--lan", "--language", diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 74db4be..04a5ea6 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -282,11 +282,13 @@ class SimulStreamingASR(): def load_model(self): model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name + lora_path = getattr(self, 'lora_path', None) whisper_model = load_model( name=model_ref, download_root=None, decoder_only=self.fast_encoder, - custom_alignment_heads=self.custom_alignment_heads + custom_alignment_heads=self.custom_alignment_heads, + lora_path=lora_path, ) warmup_audio = load_file(self.warmup_file) if warmup_audio is not None: diff --git a/whisperlivekit/whisper/__init__.py b/whisperlivekit/whisper/__init__.py index c4996cd..ce68de9 100644 --- a/whisperlivekit/whisper/__init__.py +++ b/whisperlivekit/whisper/__init__.py @@ -264,9 +264,49 @@ def _collapse_hf_module_name(module: str): return module +def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]: + """ + Resolve LoRA adapter path - handles both local paths and HuggingFace repo IDs. + + If lora_path is a local directory containing adapter files, returns it as-is. + If lora_path looks like a HuggingFace repo ID (contains '/'), downloads and caches it. 
+ """ + if not lora_path: + return None + + # Check if it's already a valid local path + if os.path.isdir(lora_path): + config_path = os.path.join(lora_path, "adapter_config.json") + if os.path.isfile(config_path): + return lora_path + + # Try to download from HuggingFace Hub + if "/" in lora_path: + try: + from huggingface_hub import snapshot_download + local_path = snapshot_download( + repo_id=lora_path, + allow_patterns=["adapter_config.json", "adapter_model.*"], + ) + return local_path + except Exception as e: + raise FileNotFoundError( + f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}" + ) + + raise FileNotFoundError( + f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID." + ) + + def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]): if not lora_path: return + + # Resolve path (handles HuggingFace Hub download) + lora_path = _resolve_lora_path(lora_path) + if not lora_path: + return config_path = os.path.join(lora_path, "adapter_config.json") if not os.path.isfile(config_path):