mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 06:14:05 +00:00
Fixes #294. improve model path backend detection and file extraction
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import List
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||
|
||||
@@ -47,24 +47,23 @@ class WhisperASR(ASRBase):
|
||||
sep = " "
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from whisperlivekit.whisper import load_model as load_model
|
||||
from whisperlivekit.whisper import load_model as load_whisper_model
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
if resolved_path.is_dir():
|
||||
pytorch_path, _, _ = model_path_and_type(resolved_path)
|
||||
if pytorch_path is None:
|
||||
model_info = detect_model_format(resolved_path)
|
||||
if not model_info.has_pytorch:
|
||||
raise FileNotFoundError(
|
||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||
)
|
||||
resolved_path = pytorch_path
|
||||
)
|
||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||
return load_model(str(resolved_path))
|
||||
return load_whisper_model(str(resolved_path))
|
||||
|
||||
if model_size is None:
|
||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||
|
||||
return load_model(model_size, download_root=cache_dir)
|
||||
return load_whisper_model(model_size, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
options = dict(self.transcribe_kargs)
|
||||
|
||||
@@ -10,7 +10,7 @@ import numpy as np
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
|
||||
from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR
|
||||
@@ -87,16 +87,20 @@ def backend_factory(
|
||||
backend_choice = backend
|
||||
custom_reference = model_path or model_dir
|
||||
resolved_root = None
|
||||
pytorch_checkpoint = None
|
||||
has_mlx_weights = False
|
||||
has_fw_weights = False
|
||||
has_pytorch = False
|
||||
|
||||
if custom_reference:
|
||||
resolved_root = resolve_model_path(custom_reference)
|
||||
if resolved_root.is_dir():
|
||||
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
|
||||
model_info = detect_model_format(resolved_root)
|
||||
has_mlx_weights = model_info.compatible_whisper_mlx
|
||||
has_fw_weights = model_info.compatible_faster_whisper
|
||||
has_pytorch = model_info.has_pytorch
|
||||
else:
|
||||
pytorch_checkpoint = resolved_root
|
||||
# Single file provided
|
||||
has_pytorch = True
|
||||
|
||||
if backend_choice == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
@@ -121,8 +125,8 @@ def backend_factory(
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
else:
|
||||
asr_cls = WhisperASR
|
||||
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
|
||||
if custom_reference and model_override is None:
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
if custom_reference and not has_pytorch:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
||||
)
|
||||
|
||||
@@ -1,49 +1,195 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Information about detected model format and files in a directory."""
|
||||
path: Optional[Path] = None
|
||||
pytorch_files: List[Path] = field(default_factory=list)
|
||||
compatible_whisper_mlx: bool = False
|
||||
compatible_faster_whisper: bool = False
|
||||
|
||||
@property
|
||||
def has_pytorch(self) -> bool:
|
||||
return len(self.pytorch_files) > 0
|
||||
|
||||
@property
|
||||
def is_sharded(self) -> bool:
|
||||
return len(self.pytorch_files) > 1
|
||||
|
||||
@property
|
||||
def primary_pytorch_file(self) -> Optional[Path]:
|
||||
"""Return the primary PyTorch file (or first shard for sharded models)."""
|
||||
if not self.pytorch_files:
|
||||
return None
|
||||
return self.pytorch_files[0]
|
||||
|
||||
|
||||
#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin
|
||||
SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$")
|
||||
|
||||
FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"}
|
||||
MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"}
|
||||
CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"}
|
||||
|
||||
|
||||
def _is_ct2_model_bin(directory: Path, filename: str) -> bool:
|
||||
"""
|
||||
Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model.
|
||||
|
||||
CTranslate2 models have specific companion files that distinguish them
|
||||
from PyTorch .bin files.
|
||||
"""
|
||||
n_indicators = 0
|
||||
for indicator in CT2_INDICATOR_FILES: #test 1
|
||||
if (directory / indicator).exists():
|
||||
n_indicators += 1
|
||||
|
||||
if n_indicators == 0:
|
||||
return False
|
||||
|
||||
config_path = directory / "config.json" #test 2
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
if config.get("model_type") == "whisper": #test 2
|
||||
return False
|
||||
except (json.JSONDecodeError, IOError):
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _collect_pytorch_files(directory: Path) -> List[Path]:
|
||||
"""
|
||||
Collect all PyTorch checkpoint files from a directory.
|
||||
|
||||
Handles:
|
||||
- Single files: model.safetensors, pytorch_model.bin, *.pt
|
||||
- Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin
|
||||
- Index-based sharded models (reads index file to find shards)
|
||||
|
||||
Returns files sorted appropriately (shards in order, or single file).
|
||||
"""
|
||||
for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
|
||||
index_path = directory / index_name
|
||||
if index_path.exists():
|
||||
try:
|
||||
with open(index_path, "r", encoding="utf-8") as f:
|
||||
index_data = json.load(f)
|
||||
weight_map = index_data.get("weight_map", {})
|
||||
if weight_map:
|
||||
shard_names = sorted(set(weight_map.values()))
|
||||
shards = [directory / name for name in shard_names if (directory / name).exists()]
|
||||
if shards:
|
||||
return shards
|
||||
except (json.JSONDecodeError, IOError):
|
||||
pass
|
||||
|
||||
sharded_groups = {}
|
||||
single_files = {}
|
||||
|
||||
for file in directory.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
if filename.startswith("adapter_"):
|
||||
continue
|
||||
|
||||
match = SHARDED_PATTERN.match(filename)
|
||||
if match:
|
||||
base_name, shard_idx, total_shards, ext = match.groups()
|
||||
key = (base_name, ext, int(total_shards))
|
||||
if key not in sharded_groups:
|
||||
sharded_groups[key] = []
|
||||
sharded_groups[key].append((int(shard_idx), file))
|
||||
continue
|
||||
|
||||
if filename == "model.safetensors":
|
||||
single_files[0] = file # Highest priority
|
||||
elif filename == "pytorch_model.bin":
|
||||
single_files[1] = file
|
||||
elif suffix == ".pt":
|
||||
single_files[2] = file
|
||||
elif suffix == ".safetensors" and not filename.startswith("adapter"):
|
||||
single_files[3] = file
|
||||
|
||||
for (base_name, ext, total_shards), shards in sharded_groups.items():
|
||||
if len(shards) == total_shards:
|
||||
return [path for _, path in sorted(shards)]
|
||||
|
||||
for priority in sorted(single_files.keys()):
|
||||
return [single_files[priority]]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def detect_model_format(model_path: Union[str, Path]) -> ModelInfo:
|
||||
"""
|
||||
Detect the model format in a given path.
|
||||
|
||||
This function analyzes a file or directory to determine:
|
||||
- What PyTorch checkpoint files are available (including sharded models)
|
||||
- Whether the directory contains MLX Whisper weights
|
||||
- Whether the directory contains Faster-Whisper (CTranslate2) weights
|
||||
|
||||
Args:
|
||||
model_path: Path to a model file or directory
|
||||
|
||||
Returns:
|
||||
ModelInfo with detected format information
|
||||
"""
|
||||
path = Path(model_path)
|
||||
info = ModelInfo(path=path)
|
||||
|
||||
if path.is_file():
|
||||
suffix = path.suffix.lower()
|
||||
if suffix in {".pt", ".safetensors", ".bin"}:
|
||||
info.pytorch_files = [path]
|
||||
return info
|
||||
|
||||
if not path.is_dir():
|
||||
return info
|
||||
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name.lower()
|
||||
|
||||
if filename in MLX_WHISPER_MARKERS:
|
||||
info.compatible_whisper_mlx = True
|
||||
|
||||
if filename in FASTER_WHISPER_MARKERS:
|
||||
if _is_ct2_model_bin(path, filename):
|
||||
info.compatible_faster_whisper = True
|
||||
|
||||
info.pytorch_files = _collect_pytorch_files(path)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||
"""
|
||||
Inspect the provided path and determine which model formats are available.
|
||||
|
||||
|
||||
This is a compatibility wrapper around detect_model_format().
|
||||
|
||||
Returns:
|
||||
pytorch_path: Path to a PyTorch checkpoint (if present).
|
||||
pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None).
|
||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
|
||||
compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist.
|
||||
"""
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pytorch_path: Optional[Path] = None
|
||||
|
||||
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
|
||||
pytorch_path = path
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
if path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name.lower()
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
if filename in {"weights.npz", "weights.safetensors"}:
|
||||
compatible_whisper_mlx = True
|
||||
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
|
||||
compatible_faster_whisper = True
|
||||
elif suffix in {".pt", ".safetensors"}:
|
||||
pytorch_path = file
|
||||
elif filename == "pytorch_model.bin":
|
||||
pytorch_path = file
|
||||
|
||||
if pytorch_path is None:
|
||||
fallback = path / "pytorch_model.bin"
|
||||
if fallback.exists():
|
||||
pytorch_path = fallback
|
||||
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
info = detect_model_format(model_path)
|
||||
return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||
@@ -59,7 +205,7 @@ def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc: # pragma: no cover - optional dependency guard
|
||||
except ImportError as exc:
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
|
||||
from whisperlivekit.backend_support import (faster_backend_available,
|
||||
mlx_backend_available)
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.model_paths import detect_model_format, resolve_model_path
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
|
||||
@@ -159,34 +159,23 @@ class SimulStreamingASR():
|
||||
self._resolved_model_path = None
|
||||
self.encoder_backend = "whisper"
|
||||
preferred_backend = getattr(self, "backend", "auto")
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
||||
|
||||
if self.model_path:
|
||||
resolved_model_path = resolve_model_path(self.model_path)
|
||||
self._resolved_model_path = resolved_model_path
|
||||
self.model_path = str(resolved_model_path)
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
||||
if self.pytorch_path:
|
||||
self.model_name = self.pytorch_path.stem
|
||||
else:
|
||||
self.model_name = Path(self.model_path).stem
|
||||
|
||||
model_info = detect_model_format(resolved_model_path)
|
||||
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
||||
compatible_faster_whisper = model_info.compatible_faster_whisper
|
||||
|
||||
if not model_info.has_pytorch:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||
)
|
||||
)
|
||||
self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
'small': './small.pt',
|
||||
'medium': './medium.pt',
|
||||
'medium.en': './medium.en.pt',
|
||||
'large-v1': './large-v1.pt',
|
||||
'base.en': './base.en.pt',
|
||||
'small.en': './small.en.pt',
|
||||
'tiny.en': './tiny.en.pt',
|
||||
'large-v2': './large-v2.pt',
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_name = self.model_size
|
||||
else:
|
||||
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||
@@ -292,9 +281,10 @@ class SimulStreamingASR():
|
||||
return True
|
||||
|
||||
def load_model(self):
|
||||
model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name
|
||||
whisper_model = load_model(
|
||||
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
||||
download_root=self.model_path,
|
||||
name=model_ref,
|
||||
download_root=None,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
)
|
||||
|
||||
@@ -319,6 +319,75 @@ def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str])
|
||||
)
|
||||
|
||||
|
||||
def _load_checkpoint(
|
||||
file_path: Union[str, Path],
|
||||
device: str,
|
||||
in_memory: bool = False,
|
||||
checkpoint_bytes: Optional[bytes] = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load a checkpoint from a single file.
|
||||
|
||||
Handles .pt, .bin, and .safetensors formats.
|
||||
"""
|
||||
if checkpoint_bytes is not None:
|
||||
with io.BytesIO(checkpoint_bytes) as fp:
|
||||
return torch.load(fp, map_location=device)
|
||||
|
||||
file_path = Path(file_path)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
if suffix == '.safetensors':
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install safetensors to load .safetensors model files: `pip install safetensors`"
|
||||
)
|
||||
return load_file(str(file_path), device=device)
|
||||
else:
|
||||
if in_memory:
|
||||
with open(file_path, "rb") as f:
|
||||
checkpoint_bytes = f.read()
|
||||
with io.BytesIO(checkpoint_bytes) as fp:
|
||||
return torch.load(fp, map_location=device)
|
||||
else:
|
||||
with open(file_path, "rb") as fp:
|
||||
return torch.load(fp, map_location=device)
|
||||
|
||||
|
||||
def _load_sharded_checkpoint(
|
||||
shard_files: List[Path],
|
||||
device: str,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load a sharded checkpoint (multiple .safetensors or .bin files).
|
||||
|
||||
Merges all shards into a single state dict.
|
||||
"""
|
||||
merged_state_dict = {}
|
||||
first_suffix = shard_files[0].suffix.lower()
|
||||
|
||||
if first_suffix == '.safetensors':
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install safetensors to load sharded .safetensors model: `pip install safetensors`"
|
||||
)
|
||||
for shard_path in shard_files:
|
||||
shard_dict = load_file(str(shard_path), device=device)
|
||||
merged_state_dict.update(shard_dict)
|
||||
else:
|
||||
for shard_path in shard_files:
|
||||
with open(shard_path, "rb") as fp:
|
||||
shard_dict = torch.load(fp, map_location=device)
|
||||
if isinstance(shard_dict, dict):
|
||||
merged_state_dict.update(shard_dict)
|
||||
|
||||
return merged_state_dict
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
@@ -336,6 +405,8 @@ def load_model(
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
Can be a single file (.pt, .bin, .safetensors), a directory containing model files,
|
||||
or a sharded model directory with files like model-00001-of-00002.safetensors.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
@@ -350,16 +421,51 @@ def load_model(
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
from whisperlivekit.model_paths import detect_model_format
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
checkpoint = None
|
||||
model_path_for_config = name # Used to find config.json for dims inference
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
if in_memory:
|
||||
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_file)
|
||||
else:
|
||||
checkpoint = _load_checkpoint(checkpoint_file, device)
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
if in_memory:
|
||||
with open(name, "rb") as f:
|
||||
checkpoint_bytes = f.read()
|
||||
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
|
||||
else:
|
||||
checkpoint = _load_checkpoint(name, device)
|
||||
model_path_for_config = name
|
||||
elif os.path.isdir(name):
|
||||
model_info = detect_model_format(name)
|
||||
|
||||
if not model_info.has_pytorch:
|
||||
raise RuntimeError(
|
||||
f"No PyTorch checkpoint found in directory {name}. "
|
||||
f"Expected .pt, .bin, or .safetensors file(s)."
|
||||
)
|
||||
|
||||
if model_info.is_sharded:
|
||||
checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
|
||||
else:
|
||||
single_file = model_info.pytorch_files[0]
|
||||
if in_memory:
|
||||
with open(single_file, "rb") as f:
|
||||
checkpoint_bytes = f.read()
|
||||
checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes)
|
||||
else:
|
||||
checkpoint = _load_checkpoint(single_file, device)
|
||||
model_path_for_config = name
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
@@ -369,22 +475,6 @@ def load_model(
|
||||
if custom_alignment_heads:
|
||||
alignment_heads = custom_alignment_heads.encode()
|
||||
|
||||
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError:
|
||||
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
|
||||
if in_memory:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
@@ -396,7 +486,7 @@ def load_model(
|
||||
if dims_cfg is not None:
|
||||
dims = ModelDimensions(**dims_cfg)
|
||||
else:
|
||||
dims = _infer_dims_from_config(name)
|
||||
dims = _infer_dims_from_config(model_path_for_config)
|
||||
if dims is None:
|
||||
raise RuntimeError(
|
||||
"Could not determine model dimensions. "
|
||||
|
||||
Reference in New Issue
Block a user