mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
This commit is contained in:
@@ -3,12 +3,14 @@
|
||||
The `--model-path` parameter accepts:
|
||||
|
||||
## File Path
|
||||
- **`.pt` format only** (required for AlignAtt policy decoder)
|
||||
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
|
||||
|
||||
## Directory Path (recommended)
|
||||
Must contain:
|
||||
- **`.pt` file** (required for decoder)
|
||||
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
|
||||
|
||||
May optionally contain:
|
||||
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||
|
||||
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.
|
||||
@@ -41,10 +41,10 @@ def model_path_and_type(model_path):
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
|
||||
if pt_path is None:
|
||||
pt_path = path if path.is_file() and path.suffix.lower() == '.bin' else None
|
||||
if path.is_dir():
|
||||
pytorch_path = None
|
||||
if path.is_file() and path.suffix.lower() in ['.pt', '.safetensors', '.bin']:
|
||||
pytorch_path = path
|
||||
elif path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if file.is_file():
|
||||
if file.name in ['weights.npz', "weights.safetensors"]:
|
||||
@@ -52,11 +52,13 @@ def model_path_and_type(model_path):
|
||||
elif file.suffix.lower() == '.bin':
|
||||
compatible_faster_whisper = True
|
||||
elif file.suffix.lower() == '.pt':
|
||||
pt_path = file
|
||||
if pt_path is None:
|
||||
pytorch_path = file
|
||||
elif file.suffix.lower() == '.safetensors':
|
||||
pytorch_path = file
|
||||
if pytorch_path is None:
|
||||
if (model_path / Path("pytorch_model.bin")).exists():
|
||||
pt_path = model_path / Path("pytorch_model.bin")
|
||||
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
pytorch_path = model_path / Path("pytorch_model.bin")
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
@@ -175,10 +177,10 @@ class SimulStreamingASR():
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
if self.model_path:
|
||||
self.pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
|
||||
self.model_name = self.pt_path.stem
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
|
||||
self.model_name = self.pytorch_path.stem
|
||||
is_multilingual = not self.model_path.endswith(".en")
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
@@ -231,9 +233,10 @@ class SimulStreamingASR():
|
||||
if self.model_path and compatible_whisper_mlx:
|
||||
mlx_model = self.model_path
|
||||
else:
|
||||
mlx_model = mlx_model_mapping[self.model_name]
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
self.fast_encoder = True
|
||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||
if mlx_model:
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
self.fast_encoder = True
|
||||
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
if self.model_path and compatible_faster_whisper:
|
||||
@@ -252,7 +255,7 @@ class SimulStreamingASR():
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(
|
||||
name=self.pt_path if self.pt_path else self.model_name,
|
||||
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
||||
download_root=self.model_path,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, Optional, Union, Dict
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
@@ -279,10 +280,20 @@ def load_model(
|
||||
if custom_alignment_heads:
|
||||
alignment_heads = custom_alignment_heads.encode()
|
||||
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError:
|
||||
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
|
||||
if in_memory:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
||||
|
||||
Reference in New Issue
Block a user