diff --git a/whisperlivekit/local_agreement/backends.py b/whisperlivekit/local_agreement/backends.py index e4cb79e..aadddc6 100644 --- a/whisperlivekit/local_agreement/backends.py +++ b/whisperlivekit/local_agreement/backends.py @@ -7,7 +7,7 @@ from typing import List import numpy as np import soundfile as sf -from whisperlivekit.model_paths import model_path_and_type, resolve_model_path +from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.timed_objects import ASRToken from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe @@ -47,24 +47,23 @@ class WhisperASR(ASRBase): sep = " " def load_model(self, model_size=None, cache_dir=None, model_dir=None): - from whisperlivekit.whisper import load_model as load_model + from whisperlivekit.whisper import load_model as load_whisper_model if model_dir is not None: - resolved_path = resolve_model_path(model_dir) + resolved_path = resolve_model_path(model_dir) if resolved_path.is_dir(): - pytorch_path, _, _ = model_path_and_type(resolved_path) - if pytorch_path is None: + model_info = detect_model_format(resolved_path) + if not model_info.has_pytorch: raise FileNotFoundError( f"No supported PyTorch checkpoint found under {resolved_path}" - ) - resolved_path = pytorch_path + ) logger.debug(f"Loading Whisper model from custom path {resolved_path}") - return load_model(str(resolved_path)) + return load_whisper_model(str(resolved_path)) if model_size is None: raise ValueError("Either model_size or model_dir must be set for WhisperASR") - return load_model(model_size, download_root=cache_dir) + return load_whisper_model(model_size, download_root=cache_dir) def transcribe(self, audio, init_prompt=""): options = dict(self.transcribe_kargs) diff --git a/whisperlivekit/local_agreement/whisper_online.py b/whisperlivekit/local_agreement/whisper_online.py index 433fa16..4256dec 100644 --- a/whisperlivekit/local_agreement/whisper_online.py +++ b/whisperlivekit/local_agreement/whisper_online.py @@ -10,7 +10,7 @@ import numpy as np from whisperlivekit.backend_support import (faster_backend_available, mlx_backend_available) -from whisperlivekit.model_paths import model_path_and_type, resolve_model_path +from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.warmup import warmup_asr from .backends import FasterWhisperASR, MLXWhisper, OpenaiApiASR, WhisperASR @@ -87,16 +87,20 @@ def backend_factory( backend_choice = backend custom_reference = model_path or model_dir resolved_root = None - pytorch_checkpoint = None has_mlx_weights = False has_fw_weights = False + has_pytorch = False if custom_reference: resolved_root = resolve_model_path(custom_reference) if resolved_root.is_dir(): - pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root) + model_info = detect_model_format(resolved_root) + has_mlx_weights = model_info.compatible_whisper_mlx + has_fw_weights = model_info.compatible_faster_whisper + has_pytorch = model_info.has_pytorch else: - pytorch_checkpoint = resolved_root + # Single file provided + has_pytorch = True if backend_choice == "openai-api": logger.debug("Using OpenAI API.") @@ -121,8 +125,8 @@ def backend_factory( model_override = str(resolved_root) if resolved_root is not None else None else: asr_cls = WhisperASR - model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None - if custom_reference and model_override is None: + model_override = str(resolved_root) if resolved_root is not None else None + if custom_reference and not has_pytorch: raise FileNotFoundError( f"No PyTorch checkpoint found under {resolved_root or custom_reference}" ) diff --git a/whisperlivekit/model_paths.py b/whisperlivekit/model_paths.py index 2e3eb0a..2f2889d 100644 --- a/whisperlivekit/model_paths.py +++ b/whisperlivekit/model_paths.py @@ -1,49 +1,195 @@ +import json +import re +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union + + +@dataclass +class ModelInfo: + """Information about detected model format and files in a directory.""" + path: Optional[Path] = None + pytorch_files: List[Path] = field(default_factory=list) + compatible_whisper_mlx: bool = False + compatible_faster_whisper: bool = False + + @property + def has_pytorch(self) -> bool: + return len(self.pytorch_files) > 0 + + @property + def is_sharded(self) -> bool: + return len(self.pytorch_files) > 1 + + @property + def primary_pytorch_file(self) -> Optional[Path]: + """Return the primary PyTorch file (or first shard for sharded models).""" + if not self.pytorch_files: + return None + return self.pytorch_files[0] + + +#regex pattern for sharded model files such as: model-00001-of-00002.safetensors or pytorch_model-00001-of-00002.bin +SHARDED_PATTERN = re.compile(r"^(.+)-(\d{5})-of-(\d{5})\.(safetensors|bin)$") + +FASTER_WHISPER_MARKERS = {"model.bin", "encoder.bin", "decoder.bin"} +MLX_WHISPER_MARKERS = {"weights.npz", "weights.safetensors"} +CT2_INDICATOR_FILES = {"vocabulary.json", "vocabulary.txt", "shared_vocabulary.json"} + + +def _is_ct2_model_bin(directory: Path, filename: str) -> bool: + """ + Determine if model.bin/encoder.bin/decoder.bin is a CTranslate2 model. + + CTranslate2 models have specific companion files that distinguish them + from PyTorch .bin files. + """ + n_indicators = 0 + for indicator in CT2_INDICATOR_FILES: #test 1 + if (directory / indicator).exists(): + n_indicators += 1 + + if n_indicators == 0: + return False + + config_path = directory / "config.json" #test 2 + if config_path.exists(): + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + if config.get("model_type") == "whisper": #test 2 + return False + except (json.JSONDecodeError, IOError): + pass + + return True + + +def _collect_pytorch_files(directory: Path) -> List[Path]: + """ + Collect all PyTorch checkpoint files from a directory. + + Handles: + - Single files: model.safetensors, pytorch_model.bin, *.pt + - Sharded files: model-00001-of-00002.safetensors, pytorch_model-00001-of-00002.bin + - Index-based sharded models (reads index file to find shards) + + Returns files sorted appropriately (shards in order, or single file). + """ + for index_name in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]: + index_path = directory / index_name + if index_path.exists(): + try: + with open(index_path, "r", encoding="utf-8") as f: + index_data = json.load(f) + weight_map = index_data.get("weight_map", {}) + if weight_map: + shard_names = sorted(set(weight_map.values())) + shards = [directory / name for name in shard_names if (directory / name).exists()] + if shards: + return shards + except (json.JSONDecodeError, IOError): + pass + + sharded_groups = {} + single_files = {} + + for file in directory.iterdir(): + if not file.is_file(): + continue + + filename = file.name + suffix = file.suffix.lower() + + if filename.startswith("adapter_"): + continue + + match = SHARDED_PATTERN.match(filename) + if match: + base_name, shard_idx, total_shards, ext = match.groups() + key = (base_name, ext, int(total_shards)) + if key not in sharded_groups: + sharded_groups[key] = [] + sharded_groups[key].append((int(shard_idx), file)) + continue + + if filename == "model.safetensors": + single_files[0] = file # Highest priority + elif filename == "pytorch_model.bin": + single_files[1] = file + elif suffix == ".pt": + single_files[2] = file + elif suffix == ".safetensors" and not filename.startswith("adapter"): + single_files[3] = file + + for (base_name, ext, total_shards), shards in sharded_groups.items(): + if len(shards) == total_shards: + return [path for _, path in sorted(shards)] + + for priority in sorted(single_files.keys()): + return [single_files[priority]] + + return [] + + +def detect_model_format(model_path: Union[str, Path]) -> ModelInfo: + """ + Detect the model format in a given path. + + This function analyzes a file or directory to determine: + - What PyTorch checkpoint files are available (including sharded models) + - Whether the directory contains MLX Whisper weights + - Whether the directory contains Faster-Whisper (CTranslate2) weights + + Args: + model_path: Path to a model file or directory + + Returns: + ModelInfo with detected format information + """ + path = Path(model_path) + info = ModelInfo(path=path) + + if path.is_file(): + suffix = path.suffix.lower() + if suffix in {".pt", ".safetensors", ".bin"}: + info.pytorch_files = [path] + return info + + if not path.is_dir(): + return info + + for file in path.iterdir(): + if not file.is_file(): + continue + + filename = file.name.lower() + + if filename in MLX_WHISPER_MARKERS: + info.compatible_whisper_mlx = True + + if filename in FASTER_WHISPER_MARKERS: + if _is_ct2_model_bin(path, filename): + info.compatible_faster_whisper = True + + info.pytorch_files = _collect_pytorch_files(path) + + return info def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]: """ Inspect the provided path and determine which model formats are available. - + + This is a compatibility wrapper around detect_model_format(). + Returns: - pytorch_path: Path to a PyTorch checkpoint (if present). + pytorch_path: Path to a PyTorch checkpoint (first shard for sharded models, or None). compatible_whisper_mlx: True if MLX weights exist in this folder. - compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist. + compatible_faster_whisper: True if Faster-Whisper (CTranslate2) weights exist. """ - path = Path(model_path) - - compatible_whisper_mlx = False - compatible_faster_whisper = False - pytorch_path: Optional[Path] = None - - if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]: - pytorch_path = path - return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper - - if path.is_dir(): - for file in path.iterdir(): - if not file.is_file(): - continue - - filename = file.name.lower() - suffix = file.suffix.lower() - - if filename in {"weights.npz", "weights.safetensors"}: - compatible_whisper_mlx = True - elif filename in {"model.bin", "encoder.bin", "decoder.bin"}: - compatible_faster_whisper = True - elif suffix in {".pt", ".safetensors"}: - pytorch_path = file - elif filename == "pytorch_model.bin": - pytorch_path = file - - if pytorch_path is None: - fallback = path / "pytorch_model.bin" - if fallback.exists(): - pytorch_path = fallback - - return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper + info = detect_model_format(model_path) + return info.primary_pytorch_file, info.compatible_whisper_mlx, info.compatible_faster_whisper def resolve_model_path(model_path: Union[str, Path]) -> Path: @@ -59,7 +205,7 @@ def resolve_model_path(model_path: Union[str, Path]) -> Path: try: from huggingface_hub import snapshot_download - except ImportError as exc: # pragma: no cover - optional dependency guard + except ImportError as exc: raise FileNotFoundError( f"Model path '{model_path}' does not exist locally and huggingface_hub " "is not installed to download it." diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 7fdf09f..74db4be 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -11,7 +11,7 @@ import torch from whisperlivekit.backend_support import (faster_backend_available, mlx_backend_available) -from whisperlivekit.model_paths import model_path_and_type, resolve_model_path +from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.simul_whisper.config import AlignAttConfig from whisperlivekit.simul_whisper.simul_whisper import AlignAtt from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript @@ -159,34 +159,23 @@ class SimulStreamingASR(): self._resolved_model_path = None self.encoder_backend = "whisper" preferred_backend = getattr(self, "backend", "auto") - self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True + compatible_whisper_mlx, compatible_faster_whisper = True, True + if self.model_path: resolved_model_path = resolve_model_path(self.model_path) self._resolved_model_path = resolved_model_path self.model_path = str(resolved_model_path) - self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path) - if self.pytorch_path: - self.model_name = self.pytorch_path.stem - else: - self.model_name = Path(self.model_path).stem + + model_info = detect_model_format(resolved_model_path) + compatible_whisper_mlx = model_info.compatible_whisper_mlx + compatible_faster_whisper = model_info.compatible_faster_whisper + + if not model_info.has_pytorch: raise FileNotFoundError( f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}" - ) + ) + self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem elif self.model_size is not None: - model_mapping = { - 'tiny': './tiny.pt', - 'base': './base.pt', - 'small': './small.pt', - 'medium': './medium.pt', - 'medium.en': './medium.en.pt', - 'large-v1': './large-v1.pt', - 'base.en': './base.en.pt', - 'small.en': './small.en.pt', - 'tiny.en': './tiny.en.pt', - 'large-v2': './large-v2.pt', - 'large-v3': './large-v3.pt', - 'large': './large-v3.pt' - } self.model_name = self.model_size else: raise ValueError("Either model_size or model_path must be specified for SimulStreaming.") @@ -292,9 +281,10 @@ class SimulStreamingASR(): return True def load_model(self): + model_ref = str(self._resolved_model_path) if self._resolved_model_path else self.model_name whisper_model = load_model( - name=self.pytorch_path if self.pytorch_path else self.model_name, - download_root=self.model_path, + name=model_ref, + download_root=None, decoder_only=self.fast_encoder, custom_alignment_heads=self.custom_alignment_heads ) diff --git a/whisperlivekit/whisper/__init__.py b/whisperlivekit/whisper/__init__.py index 1dd3504..c4996cd 100644 --- a/whisperlivekit/whisper/__init__.py +++ b/whisperlivekit/whisper/__init__.py @@ -319,6 +319,75 @@ def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]) ) +def _load_checkpoint( + file_path: Union[str, Path], + device: str, + in_memory: bool = False, + checkpoint_bytes: Optional[bytes] = None, +) -> Dict[str, torch.Tensor]: + """ + Load a checkpoint from a single file. + + Handles .pt, .bin, and .safetensors formats. + """ + if checkpoint_bytes is not None: + with io.BytesIO(checkpoint_bytes) as fp: + return torch.load(fp, map_location=device) + + file_path = Path(file_path) + suffix = file_path.suffix.lower() + + if suffix == '.safetensors': + try: + from safetensors.torch import load_file + except ImportError: + raise ImportError( + "Please install safetensors to load .safetensors model files: `pip install safetensors`" + ) + return load_file(str(file_path), device=device) + else: + if in_memory: + with open(file_path, "rb") as f: + checkpoint_bytes = f.read() + with io.BytesIO(checkpoint_bytes) as fp: + return torch.load(fp, map_location=device) + else: + with open(file_path, "rb") as fp: + return torch.load(fp, map_location=device) + + +def _load_sharded_checkpoint( + shard_files: List[Path], + device: str, +) -> Dict[str, torch.Tensor]: + """ + Load a sharded checkpoint (multiple .safetensors or .bin files). + + Merges all shards into a single state dict. + """ + merged_state_dict = {} + first_suffix = shard_files[0].suffix.lower() + + if first_suffix == '.safetensors': + try: + from safetensors.torch import load_file + except ImportError: + raise ImportError( + "Please install safetensors to load sharded .safetensors model: `pip install safetensors`" + ) + for shard_path in shard_files: + shard_dict = load_file(str(shard_path), device=device) + merged_state_dict.update(shard_dict) + else: + for shard_path in shard_files: + with open(shard_path, "rb") as fp: + shard_dict = torch.load(fp, map_location=device) + if isinstance(shard_dict, dict): + merged_state_dict.update(shard_dict) + + return merged_state_dict + + def load_model( name: str, device: Optional[Union[str, torch.device]] = None, @@ -336,6 +405,8 @@ def load_model( name : str one of the official model names listed by `whisper.available_models()`, or path to a model checkpoint containing the model dimensions and the model state_dict. + Can be a single file (.pt, .bin, .safetensors), a directory containing model files, + or a sharded model directory with files like model-00001-of-00002.safetensors. device : Union[str, torch.device] the PyTorch device to put the model into download_root: str @@ -350,16 +421,51 @@ def load_model( model : Whisper The Whisper ASR model instance """ + from whisperlivekit.model_paths import detect_model_format if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if download_root is None: default = os.path.join(os.path.expanduser("~"), ".cache") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") + + checkpoint = None + model_path_for_config = name # Used to find config.json for dims inference + if name in _MODELS: - checkpoint_file = _download(_MODELS[name], download_root, in_memory) + checkpoint_file = _download(_MODELS[name], download_root, in_memory) + if in_memory: + checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_file) + else: + checkpoint = _load_checkpoint(checkpoint_file, device) elif os.path.isfile(name): - checkpoint_file = open(name, "rb").read() if in_memory else name + if in_memory: + with open(name, "rb") as f: + checkpoint_bytes = f.read() + checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes) + else: + checkpoint = _load_checkpoint(name, device) + model_path_for_config = name + elif os.path.isdir(name): + model_info = detect_model_format(name) + + if not model_info.has_pytorch: + raise RuntimeError( + f"No PyTorch checkpoint found in directory {name}. " + f"Expected .pt, .bin, or .safetensors file(s)." + ) + + if model_info.is_sharded: + checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device) + else: + single_file = model_info.pytorch_files[0] + if in_memory: + with open(single_file, "rb") as f: + checkpoint_bytes = f.read() + checkpoint = _load_checkpoint(None, device, checkpoint_bytes=checkpoint_bytes) + else: + checkpoint = _load_checkpoint(single_file, device) + model_path_for_config = name else: raise RuntimeError( f"Model {name} not found; available models = {available_models()}" @@ -369,22 +475,6 @@ def load_model( if custom_alignment_heads: alignment_heads = custom_alignment_heads.encode() - if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors': - try: - from safetensors.torch import load_file - except ImportError: - raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`") - if in_memory: - checkpoint = load_file(checkpoint_file, device=device) - else: - checkpoint = load_file(checkpoint_file, device=device) - else: - with ( - io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") - ) as fp: - checkpoint = torch.load(fp, map_location=device) - del checkpoint_file - dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] @@ -396,7 +486,7 @@ def load_model( if dims_cfg is not None: dims = ModelDimensions(**dims_cfg) else: - dims = _infer_dims_from_config(name) + dims = _infer_dims_from_config(model_path_for_config) if dims is None: raise RuntimeError( "Could not determine model dimensions. "