5 Commits

Author SHA1 Message Date
Quentin Fuxa
a246ba9bfe v0 2025-11-09 22:02:15 +01:00
Quentin Fuxa
7108d2ddc5 fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/269 2025-11-09 20:08:18 +01:00
Quentin Fuxa
a732e0903e Add a script to detect alignment heads, useful for distilled whisper 2025-11-09 18:12:09 +01:00
Quentin Fuxa
0491681be4 Distilled model compatibility with HF config.json to ModelDimensions 2025-11-08 20:20:05 +01:00
Quentin Fuxa
ffe5284764 _processing_tasks_done checks task completion 2025-11-05 23:34:00 +01:00
13 changed files with 520 additions and 43 deletions

View File

@@ -171,7 +171,8 @@ async def websocket_endpoint(websocket: WebSocket):
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300"> | `None` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |

View File

@@ -3,12 +3,14 @@
The `--model-path` parameter accepts:
## File Path
- **`.pt` format only** (required for AlignAtt policy decoder)
- **`.pt` / `.bin` / `.safetensors` formats**: must be loadable by PyTorch/safetensors.
## Directory Path (recommended)
Must contain:
- **`.pt` file** (required for decoder)
- **`.pt` / `.bin` / `.safetensors` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
To improve speed and reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads for your model, then pass them to WLK with `--custom-alignment-heads`. Otherwise, the alignment heads default to all heads in the last half of the decoder layers.
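If I read the loader diff below correctly, the value passed to `--custom-alignment-heads` is the base85 blob itself (the contents of the `alignment_heads.b85` file the script writes). A minimal sketch of building such a blob by hand; the decoder dimensions and the (layer, head) pairs here are purely illustrative:

import base64
import gzip
import numpy as np

n_text_layer, n_text_head = 2, 20             # must match your model's decoder
mask = np.zeros((n_text_layer, n_text_head), dtype=bool)
for layer, head in [(1, 3), (1, 7)]:          # hypothetical alignment heads
    mask[layer, head] = True
blob = base64.b85encode(gzip.compress(mask.tobytes())).decode()
print(blob)                                   # value for --custom-alignment-heads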

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.13"
version = "0.2.13.post1"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
@@ -30,7 +30,6 @@ dependencies = [
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets",
"torchaudio>=2.0.0",

BIN
scripts/alignment_heads.png Normal file

Binary file not shown.


View File

@@ -0,0 +1,292 @@
"""Determine alignment heads for a variants, such as distilled model"""
from __future__ import annotations
import argparse
import base64
import gzip
import io
import pathlib
import sys
import math
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from datasets import Audio as DatasetAudio, load_dataset
import soundfile as sf
import matplotlib.pyplot as plt
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
WHISPER_ROOT = REPO_ROOT / "whisper"
sys.path.insert(0, str(REPO_ROOT))
sys.path.insert(0, str(WHISPER_ROOT))
from whisper import load_model
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
def load_dataset_clips(name, config, split, limit):
ds = load_dataset(name, config, split=split)
ds = ds.cast_column("audio", DatasetAudio(decode=False))
clips = []
for idx, row in enumerate(ds):
if limit is not None and idx >= limit:
break
audio_field = row["audio"]
transcript = row["text"]
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
if waveform_np.ndim > 1:
waveform_np = waveform_np.mean(axis=1)
waveform = waveform_np
transcript = str(transcript)
clips.append((waveform, transcript))
return clips
def load_clips(args):
return load_dataset_clips(
args.dataset,
args.dataset_config,
args.dataset_split,
args.dataset_num_samples,
)
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
return waveform
def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="pytorch_model.bin",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Torch device to run on",
)
parser.add_argument(
"--dataset",
type=str,
default="librispeech_asr"
)
parser.add_argument(
"--dataset-config",
type=str,
default="clean"
)
parser.add_argument(
"--dataset-split",
type=str,
default="validation[:1%]",
)
parser.add_argument(
"--dataset-num-samples",
type=int,
default=16,
)
parser.add_argument(
"--threshold",
type=float,
default=1.5,
help="Z score threshold for a head to be selected",
)
parser.add_argument(
"--votes",
type=float,
default=0.75,
help="percentage of clips that must vote for a head",
)
parser.add_argument(
"--output",
type=str,
default="alignment_heads.b85",
)
parser.add_argument(
"--visualize-top-k",
type=int,
default=32,
)
return parser.parse_args()
def collect_heads(
model,
tokenizer,
clips: Sequence[Tuple[AudioInput, str]],
threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = model.device
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
strengths = torch.zeros_like(votes)
for audio_source, transcript in clips:
waveform = pad_or_trim(_waveform_from_source(audio_source))
mel = log_mel_spectrogram(waveform, device=device)
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=device,
)
qks = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
for layer_idx, tensor in enumerate(qks):
if tensor is None:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1)
peak = tensor.max(dim=-1).values # [heads, tokens]
strengths[layer_idx] += peak.mean(dim=-1)
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
)
mask = (zscore > 3).any(dim=-1)
votes[layer_idx] += mask.float()
votes /= len(clips)
strengths /= len(clips)
return votes, strengths
def _select_heads_for_visualization(selection, strengths, top_k):
selected = torch.nonzero(selection, as_tuple=False)
if selected.numel() == 0:
return []
entries = [
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
for layer, head in selected
]
entries.sort(key=lambda item: item[2], reverse=True)
return entries[:top_k]
def _extract_heatmaps(
model,
tokenizer,
clip: Tuple[AudioInput, str],
heads: Sequence[Tuple[int, int, float]],
) -> dict:
if not heads:
return {}
target_map = {}
for layer, head, _ in heads:
target_map.setdefault(layer, set()).add(head)
waveform = pad_or_trim(_waveform_from_source(clip[0]))
mel = log_mel_spectrogram(waveform, device=model.device)
transcript = clip[1]
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=model.device,
)
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
heatmaps = {}
for layer_idx, tensor in enumerate(QKs):
if tensor is None or layer_idx not in target_map:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1).cpu()
for head_idx in target_map[layer_idx]:
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
return heatmaps
def _plot_heatmaps(
heads, heatmaps, output_path):
cols = min(3, len(heads))
rows = math.ceil(len(heads) / cols)
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
for idx, (layer, head, score) in enumerate(heads):
ax = axes[idx // cols][idx % cols]
mat = heatmaps.get((layer, head))
if mat is None:
ax.axis("off")
continue
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
ax.set_xlabel("time")
ax.set_ylabel("tokens")
for j in range(len(heads), rows * cols):
axes[j // cols][j % cols].axis("off")
fig.tight_layout()
fig.savefig(output_path, dpi=200)
plt.close(fig)
def _dump_mask(mask: torch.Tensor, output_path: str):
payload = mask.numpy().astype(np.bool_)
blob = base64.b85encode(gzip.compress(payload.tobytes()))
with open(output_path, "wb") as f:
f.write(blob)
def main():
args = _parse_args()
model = load_model(args.model, device=args.device)
model.eval()
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
clips = load_clips(args)
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
# selection = votes > 0.5
selection = strengths > 0.05
_dump_mask(selection.cpu(), args.output)
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
if __name__ == "__main__":
main()
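A typical run with the defaults above would be `python scripts/determine_alignment_heads.py --model path/to/pytorch_model.bin`, producing `alignment_heads.b85` plus a heatmap image. A minimal sketch of reading the dump back, mirroring `_dump_mask`; the layer/head counts are assumptions and must match your model:

import base64
import gzip
import numpy as np

n_text_layer, n_text_head = 2, 20  # hypothetical dims of a distilled decoder

with open("alignment_heads.b85", "rb") as f:
    blob = f.read()
mask = np.frombuffer(gzip.decompress(base64.b85decode(blob)), dtype=bool)
mask = mask.reshape(n_text_layer, n_text_head)
print([tuple(int(v) for v in ij) for ij in np.argwhere(mask)])  # selected (layer, head) pairs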

View File

@@ -1,9 +1,10 @@
"""Copy core files from web directory to Chrome extension directory."""
import shutil
import os
from pathlib import Path
def sync_extension_files():
    """Copy core files from web directory to Chrome extension directory."""
    web_dir = Path("whisperlivekit/web")
    extension_dir = Path("chrome-extension")

View File

@@ -464,7 +464,7 @@ class AudioProcessor:
            yield response
            self.last_response_content = response
            if self.is_stopping and self.transcription_task and self.transcription_task.done() and self.diarization_task and self.diarization_task.done():
            if self.is_stopping and self._processing_tasks_done():
                logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
                return
@@ -517,11 +517,16 @@
    async def watchdog(self, tasks_to_monitor):
        """Monitors the health of critical processing tasks."""
        tasks_remaining = [task for task in tasks_to_monitor if task]
        while True:
            try:
                if not tasks_remaining:
                    logger.info("Watchdog task finishing: all monitored tasks completed.")
                    return
                await asyncio.sleep(10)
                for i, task in enumerate(tasks_to_monitor):
                for i, task in enumerate(list(tasks_remaining)):
                    if task.done():
                        exc = task.exception()
                        task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
@@ -529,6 +534,7 @@
                            logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
                        else:
                            logger.info(f"{task_name} completed normally.")
                        tasks_remaining.remove(task)
            except asyncio.CancelledError:
                logger.info("Watchdog task cancelled.")
@@ -559,6 +565,16 @@
            self.diarization.close()
        logger.info("AudioProcessor cleanup complete.")

    def _processing_tasks_done(self):
        """Return True when all active processing tasks have completed."""
        tasks_to_check = [
            self.transcription_task,
            self.diarization_task,
            self.translation_task,
            self.ffmpeg_reader_task,
        ]
        return all(task.done() for task in tasks_to_check if task)

    async def process_audio(self, message):
        """Process incoming audio data."""

View File

@@ -179,5 +179,4 @@ def online_translation_factory(args, translation_model):
    # one shared NLLB model for all speakers
    # one tokenizer per speaker/language
    from nllw import OnlineTranslation
    from nllw import OnlineTranslation
    return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -103,5 +103,4 @@ def handle_silences(tokens, beg_loop, vac_detected_silence):
    tokens = blank_to_silence(tokens)  # useful for the simulstreaming backend, which tends to generate [BLANK_AUDIO] text
    tokens = no_token_to_silence(tokens)
    tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
    return tokens
    return tokens

View File

@@ -1,4 +1,3 @@
import logging
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, Segment, format_time
@@ -164,4 +163,4 @@ def format_output(state, silence, args, sep):
translation=segment.translation
))
return lines, undiarized_text
return lines, undiarized_text

View File

@@ -23,7 +23,7 @@ try:
    HAS_MLX_WHISPER = True
except ImportError:
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
        print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: `pip install mlx-whisper`\n{"="*50}""")
    HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
    HAS_FASTER_WHISPER = False
@@ -32,6 +32,8 @@ else:
        from faster_whisper import WhisperModel
        HAS_FASTER_WHISPER = True
    except ImportError:
        if platform.system() != "Darwin":
            print(f"""{"="*50}\nFaster-Whisper not found. Consider installing faster-whisper for better performance: `pip install faster-whisper`\n{"="*50}""")
        HAS_FASTER_WHISPER = False
def model_path_and_type(model_path):
@@ -39,9 +41,10 @@ def model_path_and_type(model_path):
    compatible_whisper_mlx = False
    compatible_faster_whisper = False
    pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
    if path.is_dir():
    pytorch_path = None
    if path.is_file() and path.suffix.lower() in ['.pt', '.safetensors', '.bin']:
        pytorch_path = path
    elif path.is_dir():
        for file in path.iterdir():
            if file.is_file():
                if file.name in ['weights.npz', "weights.safetensors"]:
@@ -49,8 +52,13 @@ def model_path_and_type(model_path):
                elif file.suffix.lower() == '.bin':
                    compatible_faster_whisper = True
                elif file.suffix.lower() == '.pt':
                    pt_path = file
    return pt_path, compatible_whisper_mlx, compatible_faster_whisper
                    pytorch_path = file
                elif file.suffix.lower() == '.safetensors':
                    pytorch_path = file
    if pytorch_path is None:
        if (model_path / Path("pytorch_model.bin")).exists():
            pytorch_path = model_path / Path("pytorch_model.bin")
    return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
class SimulStreamingOnlineProcessor:
@@ -169,11 +177,11 @@ class SimulStreamingASR():
        self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
        self.fast_encoder = False
        pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
        self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
        if self.model_path:
            pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
            self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
            self.model_name = self.pytorch_path.stem
            is_multilingual = not self.model_path.endswith(".en")
        elif self.model_size is not None:
            model_mapping = {
                'tiny': './tiny.pt',
@@ -189,12 +197,11 @@ class SimulStreamingASR():
                'large-v3': './large-v3.pt',
                'large': './large-v3.pt'
            }
            pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
            self.model_name = pt_path.name.replace(".pt", "")
            self.model_name = self.model_size
            is_multilingual = not self.model_name.endswith(".en")
        self.cfg = AlignAttConfig(
            tokenizer_is_multilingual=not self.model_name.endswith(".en"),
            tokenizer_is_multilingual=is_multilingual,
            segment_length=self.min_chunk_size,
            frame_threshold=self.frame_threshold,
            language=self.lan,
@@ -226,9 +233,10 @@ class SimulStreamingASR():
            if self.model_path and compatible_whisper_mlx:
                mlx_model = self.model_path
            else:
                mlx_model = mlx_model_mapping[self.model_name]
            self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
            self.fast_encoder = True
                mlx_model = mlx_model_mapping.get(self.model_name)
            if mlx_model:
                self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
                self.fast_encoder = True
        elif HAS_FASTER_WHISPER and compatible_faster_whisper:
            print('Simulstreaming will use Faster Whisper for the encoder.')
            if self.model_path and compatible_faster_whisper:
@@ -247,7 +255,7 @@ class SimulStreamingASR():
    def load_model(self):
        whisper_model = load_model(
            name=self.model_path if self.model_path else self.model_name,
            name=self.pytorch_path if self.pytorch_path else self.model_name,
            download_root=self.model_path,
            decoder_only=self.fast_encoder,
            custom_alignment_heads=self.custom_alignment_heads
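For illustration, plausible outcomes of the new `model_path_and_type` resolution, traced from the code above (layouts hypothetical):

# single checkpoint file:  "large-v3.pt"          -> (Path("large-v3.pt"), False, False)
# dir with weights.npz + model.pt                 -> (<dir>/model.pt, True, False)
# dir with only pytorch_model.bin                 -> (<dir>/pytorch_model.bin, False, True)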

View File

@@ -1,12 +1,14 @@
import hashlib
import io
import json
import os
import urllib
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict
import torch
from tqdm import tqdm
from pathlib import Path
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
@@ -100,6 +102,137 @@ def available_models() -> List[str]:
    return list(_MODELS.keys())
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
    """
    Attempt to infer ModelDimensions from an HF-style config.json located
    next to the given checkpoint; useful for distilled models.
    """
    candidates = []
    if os.path.isdir(path):
        candidates.append(os.path.join(path, "config.json"))
    else:
        candidates.append(os.path.join(os.path.dirname(path), "config.json"))
    for candidate in candidates:
        if not os.path.isfile(candidate):
            continue
        with open(candidate, "r", encoding="utf-8") as f:
            config = json.load(f)
        try:
            return ModelDimensions(
                n_mels=config["num_mel_bins"],
                n_audio_ctx=config["max_source_positions"],
                n_audio_state=config["d_model"],
                n_audio_head=config["encoder_attention_heads"],
                n_audio_layer=config.get("encoder_layers")
                or config["num_hidden_layers"],
                n_vocab=config["vocab_size"],
                n_text_ctx=config["max_target_positions"],
                n_text_state=config["d_model"],
                n_text_head=config["decoder_attention_heads"],
                n_text_layer=config["decoder_layers"],
            )
        except KeyError as err:
            warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
            return None
    return None
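# Example (illustrative numbers resembling a distilled large model's config.json):
#   {"num_mel_bins": 128, "d_model": 1280, "encoder_layers": 32, "decoder_layers": 2,
#    "encoder_attention_heads": 20, "decoder_attention_heads": 20,
#    "vocab_size": 51866, "max_source_positions": 1500, "max_target_positions": 448}
# would yield ModelDimensions(n_mels=128, n_audio_ctx=1500, n_audio_state=1280,
# n_audio_head=20, n_audio_layer=32, n_vocab=51866, n_text_ctx=448,
# n_text_state=1280, n_text_head=20, n_text_layer=2).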
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert an HF checkpoint state_dict into the naming convention used by
    vanilla Whisper checkpoints.
    """
    if not any(k.startswith("model.") for k in state_dict):
        return state_dict

    def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
        if remainder.startswith("self_attn."):
            suffix = remainder.split(".", 1)[1]
            mapping = {
                "q_proj": "attn.query",
                "k_proj": "attn.key",
                "v_proj": "attn.value",
                "out_proj": "attn.out",
            }
            stem = mapping.get(suffix.split(".")[0])
            if stem:
                rest = suffix.split(".", 1)[1] if "." in suffix else ""
                return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
        elif remainder == "self_attn_layer_norm.weight":
            return f"{target_prefix}.attn_ln.weight"
        elif remainder == "self_attn_layer_norm.bias":
            return f"{target_prefix}.attn_ln.bias"
        elif remainder.startswith("encoder_attn."):
            suffix = remainder.split(".", 1)[1]
            mapping = {
                "q_proj": "cross_attn.query",
                "k_proj": "cross_attn.key",
                "v_proj": "cross_attn.value",
                "out_proj": "cross_attn.out",
            }
            stem = mapping.get(suffix.split(".", 1)[0])
            if stem:
                rest = suffix.split(".", 1)[1] if "." in suffix else ""
                return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
        elif remainder == "encoder_attn_layer_norm.weight":
            return f"{target_prefix}.cross_attn_ln.weight"
        elif remainder == "encoder_attn_layer_norm.bias":
            return f"{target_prefix}.cross_attn_ln.bias"
        elif remainder.startswith("fc1."):
            return f"{target_prefix}.mlp.0.{remainder.split('.', 1)[1]}"
        elif remainder.startswith("fc2."):
            return f"{target_prefix}.mlp.2.{remainder.split('.', 1)[1]}"
        elif remainder == "final_layer_norm.weight":
            return f"{target_prefix}.mlp_ln.weight"
        elif remainder == "final_layer_norm.bias":
            return f"{target_prefix}.mlp_ln.bias"
        return None

    converted = {}
    for key, value in state_dict.items():
        if not key.startswith("model."):
            continue
        subkey = key[len("model."):]
        if subkey.startswith("encoder.layers."):
            parts = subkey.split(".")
            layer_idx = parts[2]
            remainder = ".".join(parts[3:])
            mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
        elif subkey.startswith("decoder.layers."):
            parts = subkey.split(".")
            layer_idx = parts[2]
            remainder = ".".join(parts[3:])
            mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
        elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
            mapped = subkey
        elif subkey == "encoder.embed_positions.weight":
            mapped = "encoder.positional_embedding"
        elif subkey == "decoder.embed_positions.weight":
            mapped = "decoder.positional_embedding"
        elif subkey == "encoder.layer_norm.weight":
            mapped = "encoder.ln_post.weight"
        elif subkey == "encoder.layer_norm.bias":
            mapped = "encoder.ln_post.bias"
        elif subkey.startswith("decoder.embed_tokens."):
            mapped = subkey.replace("embed_tokens", "token_embedding", 1)
        elif subkey == "decoder.layer_norm.weight":
            mapped = "decoder.ln.weight"
        elif subkey == "decoder.layer_norm.bias":
            mapped = "decoder.ln.bias"
        else:
            mapped = None
        if mapped:
            converted[mapped] = value
    return converted if converted else state_dict
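# Examples of the renames performed above (traced through map_block):
#   model.encoder.layers.3.self_attn.q_proj.weight    -> encoder.blocks.3.attn.query.weight
#   model.decoder.layers.0.encoder_attn.k_proj.weight -> decoder.blocks.0.cross_attn.key.weight
#   model.decoder.layers.1.fc1.bias                   -> decoder.blocks.1.mlp.0.bias
#   model.decoder.embed_tokens.weight                 -> decoder.token_embedding.weight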
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
@@ -134,7 +267,6 @@ def load_model(
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
@@ -148,22 +280,50 @@
    if custom_alignment_heads:
        alignment_heads = custom_alignment_heads.encode()
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
        try:
            from safetensors.torch import load_file
        except ImportError:
            raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
        # checkpoint_file is only ever a Path for on-disk files, so the
        # in-memory case never reaches this branch; one call suffices
        checkpoint = load_file(checkpoint_file, device=device)
    else:
        with (
            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file
    dims = ModelDimensions(**checkpoint["dims"])
    dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    state_dict = _convert_hf_state_dict(state_dict)
    if dims_cfg is not None:
        dims = ModelDimensions(**dims_cfg)
    else:
        dims = _infer_dims_from_config(name)
        if dims is None:
            raise RuntimeError(
                "Could not determine model dimensions. "
                "Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
            )
    if not isinstance(state_dict, dict):
        state_dict = checkpoint
    model = Whisper(dims, decoder_only=decoder_only)
    if decoder_only:
        checkpoint["model_state_dict"] = {
            k: v for k, v in checkpoint["model_state_dict"].items()
        state_dict = {
            k: v for k, v in state_dict.items()
            if 'encoder' not in k
        }
    model.load_state_dict(checkpoint["model_state_dict"])
    model.load_state_dict(state_dict)
    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)
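Taken together, these loader changes mean an HF-layout distilled checkpoint should load directly, as long as its config.json sits next to the weights. A minimal sketch, with a hypothetical path and layout:

from whisper import load_model

# directory holding pytorch_model.bin (or model.safetensors) and config.json
model = load_model("models/distil-whisper/pytorch_model.bin", device="cpu")
print(model.dims)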

View File

@@ -287,3 +287,4 @@ class State():
    remaining_time_transcription: float = 0.0
    remaining_time_diarization: float = 0.0
    beg_loop: Optional[int] = None