# Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07)
#!/usr/bin/env python3
"""
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.

Optionally shrink the supported audio chunk length (in seconds) by trimming the
encoder positional embeddings and updating the stored model dimensions.
"""

import argparse
import json
import os
from pathlib import Path
from typing import Dict, Tuple

import torch

from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
from whisperlivekit.whisper.model import ModelDimensions
from whisperlivekit.whisper.utils import exact_div
from whisperlivekit.whisper import _convert_hf_state_dict

def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
|
safetensor_path = repo_path / "model.safetensors"
|
|
bin_path = repo_path / "pytorch_model.bin"
|
|
|
|
if safetensor_path.is_file():
|
|
try:
|
|
from safetensors.torch import load_file # type: ignore
|
|
except Exception as exc: # pragma: no cover - import guard
|
|
raise RuntimeError(
|
|
"Install safetensors to load model.safetensors "
|
|
"(pip install safetensors)"
|
|
) from exc
|
|
return load_file(str(safetensor_path))
|
|
|
|
if bin_path.is_file():
|
|
return torch.load(bin_path, map_location="cpu")
|
|
|
|
raise FileNotFoundError(
|
|
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
|
|
)
|
|
|
|
|
|
def _load_config(repo_path: Path) -> Dict:
|
|
config_path = repo_path / "config.json"
|
|
if not config_path.is_file():
|
|
raise FileNotFoundError(
|
|
f"Hugging Face checkpoint at {repo_path} is missing config.json"
|
|
)
|
|
with open(config_path, "r", encoding="utf-8") as fp:
|
|
return json.load(fp)
|
|
|
|
|
|
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
    """Translate a chunk length in seconds into (n_frames, n_audio_ctx).

    Raises:
        ValueError: *chunk_length* does not land on a whole number of samples.
    """
    exact_samples = chunk_length * SAMPLE_RATE
    n_samples = int(round(exact_samples))
    # Reject chunk lengths whose sample count is not (numerically) an integer.
    if abs(n_samples - exact_samples) > 1e-6:
        raise ValueError(
            "chunk_length must align with sample rate so that "
            "chunk_length * SAMPLE_RATE is an integer"
        )
    n_frames = exact_div(n_samples, HOP_LENGTH)
    # The encoder downsamples by 2, so the context is half the frame count.
    return n_frames, exact_div(n_frames, 2)
def _build_dims(config: Dict, chunk_length: float) -> Dict:
    """Assemble the dims mapping stored in the .pt package from an HF config.

    Starts from the config's native dimensions, then overrides the audio
    context (and records ``chunk_length``) for the requested chunk size.
    """
    # Some configs spell the encoder depth differently; prefer encoder_layers.
    encoder_layers = config.get("encoder_layers") or config["num_hidden_layers"]
    dims = vars(
        ModelDimensions(
            n_mels=config["num_mel_bins"],
            n_audio_ctx=config["max_source_positions"],
            n_audio_state=config["d_model"],
            n_audio_head=config["encoder_attention_heads"],
            n_audio_layer=encoder_layers,
            n_vocab=config["vocab_size"],
            n_text_ctx=config["max_target_positions"],
            n_text_state=config["d_model"],
            n_text_head=config["decoder_attention_heads"],
            n_text_layer=config["decoder_layers"],
        )
    ).copy()

    _, n_audio_ctx = _derive_audio_ctx(chunk_length)
    dims["n_audio_ctx"] = n_audio_ctx
    dims["chunk_length"] = chunk_length
    return dims
def _trim_positional_embedding(
|
|
state_dict: Dict[str, torch.Tensor], target_ctx: int
|
|
) -> None:
|
|
key = "encoder.positional_embedding"
|
|
if key not in state_dict:
|
|
raise KeyError(f"{key} missing from converted state dict")
|
|
|
|
tensor = state_dict[key]
|
|
if tensor.shape[0] < target_ctx:
|
|
raise ValueError(
|
|
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
|
|
)
|
|
if tensor.shape[0] == target_ctx:
|
|
return
|
|
state_dict[key] = tensor[:target_ctx].contiguous()
|
|
|
|
|
|
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
    """Convert an HF Whisper checkpoint at *hf_path* into a .pt at *output_path*.

    Loads and renames the weights, derives the model dimensions for
    *chunk_length* seconds of audio, trims the encoder positional embedding to
    match, and saves the ``{"dims", "model_state_dict"}`` package.
    """
    converted = _convert_hf_state_dict(_load_state_dict(hf_path))
    dims = _build_dims(_load_config(hf_path), chunk_length)

    _trim_positional_embedding(converted, dims["n_audio_ctx"])

    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"dims": dims, "model_state_dict": converted}, output_path)
def parse_args() -> argparse.Namespace:
    """Build and evaluate the converter's command-line interface."""
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
    )
    add = parser.add_argument
    add(
        "hf_path",
        type=str,
        help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
    )
    add(
        "--output",
        type=str,
        default="converted-whisper.pt",
        help="Destination path for the .pt file",
    )
    add(
        "--chunk-length",
        type=float,
        default=30.0,
        help="Audio chunk length in seconds to support (default: 30)",
    )
    return parser.parse_args()
def main():
    """CLI entry point: convert the checkpoint and report where it was saved."""
    args = parse_args()
    # Expand ~ and normalize both paths before handing them to the converter.
    source = Path(os.path.expanduser(args.hf_path)).resolve()
    destination = Path(os.path.expanduser(args.output)).resolve()
    convert_checkpoint(source, destination, args.chunk_length)
    print(f"Saved converted checkpoint to {destination}")
# Run the converter only when executed as a script, not on import.
if __name__ == "__main__":
    main()