diff --git a/docs/models_compatible_formats.md b/docs/models_compatible_formats.md index fb34d8e..e6a760b 100644 --- a/docs/models_compatible_formats.md +++ b/docs/models_compatible_formats.md @@ -3,12 +3,14 @@ The `--model-path` parameter accepts: ## File Path -- **`.pt` format only** (required for AlignAtt policy decoder) +- **`.pt` / `.bin` / `.safetensors` formats** Should be openable by pytorch/safetensors. ## Directory Path (recommended) Must contain: -- **`.pt` file** (required for decoder) +- **`.pt` / `.bin` / `.safetensors` file** (required for decoder) May optionally contain: - **`.bin` file** - faster-whisper model for encoder (requires faster-whisper) -- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx) \ No newline at end of file +- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx) + +To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` option to pass them to WLK. If not, alignment heads are set to be all the heads of the last half of the decoder layers. 
\ No newline at end of file diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 26bce5f..04b4705 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -41,10 +41,10 @@ def model_path_and_type(model_path): compatible_whisper_mlx = False compatible_faster_whisper = False - pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None - if pt_path is None: - pt_path = path if path.is_file() and path.suffix.lower() == '.bin' else None - if path.is_dir(): + pytorch_path = None + if path.is_file() and path.suffix.lower() in ['.pt', '.safetensors', '.bin']: + pytorch_path = path + elif path.is_dir(): for file in path.iterdir(): if file.is_file(): if file.name in ['weights.npz', "weights.safetensors"]: @@ -52,11 +52,13 @@ def model_path_and_type(model_path): elif file.suffix.lower() == '.bin': compatible_faster_whisper = True elif file.suffix.lower() == '.pt': - pt_path = file - if pt_path is None: + pytorch_path = file + elif file.suffix.lower() == '.safetensors': + pytorch_path = file + if pytorch_path is None: if (model_path / Path("pytorch_model.bin")).exists(): - pt_path = model_path / Path("pytorch_model.bin") - return pt_path, compatible_whisper_mlx, compatible_faster_whisper + pytorch_path = model_path / Path("pytorch_model.bin") + return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper class SimulStreamingOnlineProcessor: @@ -175,10 +177,10 @@ class SimulStreamingASR(): self.decoder_type = 'greedy' if self.beams == 1 else 'beam' self.fast_encoder = False - self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True + self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True if self.model_path: - self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path) - self.model_name = self.pt_path.stem + self.pytorch_path, compatible_whisper_mlx, 
compatible_faster_whisper = model_path_and_type(self.model_path) + self.model_name = self.pytorch_path.stem is_multilingual = not self.model_path.endswith(".en") elif self.model_size is not None: model_mapping = { @@ -231,9 +233,10 @@ class SimulStreamingASR(): if self.model_path and compatible_whisper_mlx: mlx_model = self.model_path else: - mlx_model = mlx_model_mapping[self.model_name] - self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model) - self.fast_encoder = True + mlx_model = mlx_model_mapping.get(self.model_name) + if mlx_model: + self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model) + self.fast_encoder = True elif HAS_FASTER_WHISPER and compatible_faster_whisper: print('Simulstreaming will use Faster Whisper for the encoder.') if self.model_path and compatible_faster_whisper: @@ -252,7 +255,7 @@ class SimulStreamingASR(): def load_model(self): whisper_model = load_model( - name=self.pt_path if self.pt_path else self.model_name, + name=self.pytorch_path if self.pytorch_path else self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder, custom_alignment_heads=self.custom_alignment_heads diff --git a/whisperlivekit/simul_whisper/whisper/__init__.py b/whisperlivekit/simul_whisper/whisper/__init__.py index 7664166..751020a 100644 --- a/whisperlivekit/simul_whisper/whisper/__init__.py +++ b/whisperlivekit/simul_whisper/whisper/__init__.py @@ -8,6 +8,7 @@ from typing import List, Optional, Union, Dict import torch from tqdm import tqdm +from pathlib import Path from .audio import load_audio, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult, decode, detect_language @@ -279,10 +280,17 @@ def load_model( if custom_alignment_heads: alignment_heads = custom_alignment_heads.encode() - with ( - io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") - ) as fp: - checkpoint = torch.load(fp, map_location=device) + if isinstance(checkpoint_file, Path) and checkpoint_file.suffix.lower() == '.safetensors': + try: + from safetensors.torch import load_file + except ImportError: + raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`") + checkpoint = load_file(checkpoint_file, device=device) + else: + with ( + io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") + ) as fp: + checkpoint = torch.load(fp, map_location=device) del checkpoint_file dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None