Quentin Fuxa
2025-11-09 20:08:18 +01:00
parent a732e0903e
commit 7108d2ddc5
3 changed files with 38 additions and 22 deletions

View File

@@ -3,12 +3,14 @@
The `--model-path` parameter accepts:
## File Path
- **`.pt` format only** (required for AlignAtt policy decoder)
- **`.pt` / `.bin` / `.safetensors` formats** — must be openable by PyTorch/safetensors.
## Directory Path (recommended)
Must contain:
- **`.pt` file** (required for decoder)
- **`.pt` / `.bin` / `.safetensors` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` option to pass them to WLK. If not, alignment heads are set to be all the heads of the last half of the decoder layers.

View File

@@ -41,10 +41,10 @@ def model_path_and_type(model_path):
compatible_whisper_mlx = False
compatible_faster_whisper = False
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
if pt_path is None:
pt_path = path if path.is_file() and path.suffix.lower() == '.bin' else None
if path.is_dir():
pytorch_path = None
if path.is_file() and path.suffix.lower() in ['.pt', '.safetensors', '.bin']:
pytorch_path = path
elif path.is_dir():
for file in path.iterdir():
if file.is_file():
if file.name in ['weights.npz', "weights.safetensors"]:
@@ -52,11 +52,13 @@ def model_path_and_type(model_path):
elif file.suffix.lower() == '.bin':
compatible_faster_whisper = True
elif file.suffix.lower() == '.pt':
pt_path = file
if pt_path is None:
pytorch_path = file
elif file.suffix.lower() == '.safetensors':
pytorch_path = file
if pytorch_path is None:
if (model_path / Path("pytorch_model.bin")).exists():
pt_path = model_path / Path("pytorch_model.bin")
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
pytorch_path = model_path / Path("pytorch_model.bin")
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
class SimulStreamingOnlineProcessor:
@@ -175,10 +177,10 @@ class SimulStreamingASR():
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.fast_encoder = False
self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
if self.model_path:
self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
self.model_name = self.pt_path.stem
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
self.model_name = self.pytorch_path.stem
is_multilingual = not self.model_path.endswith(".en")
elif self.model_size is not None:
model_mapping = {
@@ -231,9 +233,10 @@ class SimulStreamingASR():
if self.model_path and compatible_whisper_mlx:
mlx_model = self.model_path
else:
mlx_model = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.fast_encoder = True
mlx_model = mlx_model_mapping.get(self.model_name)
if mlx_model:
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.fast_encoder = True
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
print('Simulstreaming will use Faster Whisper for the encoder.')
if self.model_path and compatible_faster_whisper:
@@ -252,7 +255,7 @@ class SimulStreamingASR():
def load_model(self):
whisper_model = load_model(
name=self.pt_path if self.pt_path else self.model_name,
name=self.pytorch_path if self.pytorch_path else self.model_name,
download_root=self.model_path,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads

View File

@@ -8,6 +8,7 @@ from typing import List, Optional, Union, Dict
import torch
from tqdm import tqdm
from pathlib import Path
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
@@ -279,10 +280,20 @@ def load_model(
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
if in_memory:
checkpoint = load_file(checkpoint_file, device=device)
else:
checkpoint = load_file(checkpoint_file, device=device)
else:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None