mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
LoRa path v0 - functional
This commit is contained in:
@@ -158,6 +158,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||
| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
|
||||
| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
||||
| `--lora-path` | Path or Hugging Face repo ID for LoRA adapter weights (e.g., `qfuxa/whisper-base-french-lora`). Only works with native Whisper backend (`--backend whisper`) | `None` |
|
||||
|
||||
| Translation options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
|
||||
@@ -59,6 +59,7 @@ class TranscriptionEngine:
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"model_path": None,
|
||||
"lora_path": None,
|
||||
"lan": "auto",
|
||||
"direct_english_translation": False,
|
||||
}
|
||||
|
||||
@@ -16,9 +16,10 @@ class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
|
||||
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
self.lora_path = lora_path
|
||||
if lan == "auto":
|
||||
self.original_language = None
|
||||
else:
|
||||
@@ -58,12 +59,12 @@ class WhisperASR(ASRBase):
|
||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||
)
|
||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||
return load_whisper_model(str(resolved_path))
|
||||
return load_whisper_model(str(resolved_path), lora_path=self.lora_path)
|
||||
|
||||
if model_size is None:
|
||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||
|
||||
return load_whisper_model(model_size, download_root=cache_dir)
|
||||
return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
options = dict(self.transcribe_kargs)
|
||||
|
||||
@@ -77,6 +77,7 @@ def backend_factory(
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
model_path,
|
||||
lora_path,
|
||||
direct_english_translation,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
@@ -138,6 +139,7 @@ def backend_factory(
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_override,
|
||||
lora_path=lora_path if backend_choice == "whisper" else None,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
@@ -106,6 +106,13 @@ def parse_args():
|
||||
default=None,
|
||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="lora_path",
|
||||
help="Path or Hugging Face repo ID for LoRA adapter weights (e.g., QuentinFuxa/whisper-base-french-lora). Only works with native Whisper backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lan",
|
||||
"--language",
|
||||
|
||||
@@ -282,11 +282,13 @@ class SimulStreamingASR():
|
||||
|
||||
def load_model(self):
|
||||
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
|
||||
lora_path = getattr(self, 'lora_path', None)
|
||||
whisper_model = load_model(
|
||||
name=model_ref,
|
||||
download_root=None,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
custom_alignment_heads=self.custom_alignment_heads,
|
||||
lora_path=lora_path,
|
||||
)
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
|
||||
@@ -264,9 +264,49 @@ def _collapse_hf_module_name(module: str):
|
||||
return module
|
||||
|
||||
|
||||
def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Resolve LoRA adapter path - handles both local paths and HuggingFace repo IDs.
|
||||
|
||||
If lora_path is a local directory containing adapter files, returns it as-is.
|
||||
If lora_path looks like a HuggingFace repo ID (contains '/'), downloads and caches it.
|
||||
"""
|
||||
if not lora_path:
|
||||
return None
|
||||
|
||||
# Check if it's already a valid local path
|
||||
if os.path.isdir(lora_path):
|
||||
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||
if os.path.isfile(config_path):
|
||||
return lora_path
|
||||
|
||||
# Try to download from HuggingFace Hub
|
||||
if "/" in lora_path:
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
local_path = snapshot_download(
|
||||
repo_id=lora_path,
|
||||
allow_patterns=["adapter_config.json", "adapter_model.*"],
|
||||
)
|
||||
return local_path
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(
|
||||
f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
|
||||
)
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
|
||||
)
|
||||
|
||||
|
||||
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
|
||||
if not lora_path:
|
||||
return
|
||||
|
||||
# Resolve path (handles HuggingFace Hub download)
|
||||
lora_path = _resolve_lora_path(lora_path)
|
||||
if not lora_path:
|
||||
return
|
||||
|
||||
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||
if not os.path.isfile(config_path):
|
||||
|
||||
Reference in New Issue
Block a user