mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 06:14:05 +00:00
hf compatibility
This commit is contained in:
@@ -45,7 +45,7 @@ pip install whisperlivekit
|
||||
#### Quick Start
|
||||
1. **Start the transcription server:**
|
||||
```bash
|
||||
whisperlivekit-server --model base --language en
|
||||
wlk --model base --language en
|
||||
```
|
||||
|
||||
2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
|
||||
@@ -53,6 +53,7 @@ pip install whisperlivekit
|
||||
|
||||
> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
|
||||
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
|
||||
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.
|
||||
|
||||
#### Use it to capture audio from web pages.
|
||||
|
||||
@@ -85,10 +86,10 @@ See **Parameters & Configuration** below on how to use them.
|
||||
|
||||
```bash
|
||||
# Large model and translate from french to danish
|
||||
whisperlivekit-server --model large-v3 --language fr --target-language da
|
||||
wlk --model large-v3 --language fr --target-language da
|
||||
|
||||
# Diarization and server listening on */80
|
||||
whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
|
||||
@@ -139,7 +140,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
||||
| `--model-path` | .pt file/directory containing whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
|
||||
@@ -13,4 +13,7 @@ May optionally contain:
|
||||
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||
|
||||
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.
|
||||
## Hugging Face Repo ID
|
||||
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
||||
|
||||
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.
|
||||
|
||||
@@ -34,6 +34,7 @@ dependencies = [
|
||||
"websockets",
|
||||
"torchaudio>=2.0.0",
|
||||
"torch>=2.0.0",
|
||||
"huggingface-hub>=0.25.0",
|
||||
"tqdm",
|
||||
"tiktoken",
|
||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||
@@ -48,6 +49,7 @@ Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
|
||||
[project.scripts]
|
||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||
wlk = "whisperlivekit.basic_server:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = [
|
||||
|
||||
153
scripts/convert_hf_whisper.py
Normal file
153
scripts/convert_hf_whisper.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
|
||||
|
||||
Optionally shrink the supported audio chunk length (in seconds) by trimming the
|
||||
encoder positional embeddings and updating the stored model dimensions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||
from whisperlivekit.whisper.model import ModelDimensions
|
||||
from whisperlivekit.whisper.utils import exact_div
|
||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||
|
||||
|
||||
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
||||
safetensor_path = repo_path / "model.safetensors"
|
||||
bin_path = repo_path / "pytorch_model.bin"
|
||||
|
||||
if safetensor_path.is_file():
|
||||
try:
|
||||
from safetensors.torch import load_file # type: ignore
|
||||
except Exception as exc: # pragma: no cover - import guard
|
||||
raise RuntimeError(
|
||||
"Install safetensors to load model.safetensors "
|
||||
"(pip install safetensors)"
|
||||
) from exc
|
||||
return load_file(str(safetensor_path))
|
||||
|
||||
if bin_path.is_file():
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
|
||||
)
|
||||
|
||||
|
||||
def _load_config(repo_path: Path) -> Dict:
|
||||
config_path = repo_path / "config.json"
|
||||
if not config_path.is_file():
|
||||
raise FileNotFoundError(
|
||||
f"Hugging Face checkpoint at {repo_path} is missing config.json"
|
||||
)
|
||||
with open(config_path, "r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
|
||||
n_samples = int(round(chunk_length * SAMPLE_RATE))
|
||||
expected_samples = chunk_length * SAMPLE_RATE
|
||||
if abs(n_samples - expected_samples) > 1e-6:
|
||||
raise ValueError(
|
||||
"chunk_length must align with sample rate so that "
|
||||
"chunk_length * SAMPLE_RATE is an integer"
|
||||
)
|
||||
n_frames = exact_div(n_samples, HOP_LENGTH)
|
||||
n_audio_ctx = exact_div(n_frames, 2)
|
||||
return n_frames, n_audio_ctx
|
||||
|
||||
|
||||
def _build_dims(config: Dict, chunk_length: float) -> Dict:
|
||||
base_dims = ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
).__dict__.copy()
|
||||
|
||||
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
|
||||
base_dims["n_audio_ctx"] = n_audio_ctx
|
||||
base_dims["chunk_length"] = chunk_length
|
||||
return base_dims
|
||||
|
||||
|
||||
def _trim_positional_embedding(
|
||||
state_dict: Dict[str, torch.Tensor], target_ctx: int
|
||||
) -> None:
|
||||
key = "encoder.positional_embedding"
|
||||
if key not in state_dict:
|
||||
raise KeyError(f"{key} missing from converted state dict")
|
||||
|
||||
tensor = state_dict[key]
|
||||
if tensor.shape[0] < target_ctx:
|
||||
raise ValueError(
|
||||
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
|
||||
)
|
||||
if tensor.shape[0] == target_ctx:
|
||||
return
|
||||
state_dict[key] = tensor[:target_ctx].contiguous()
|
||||
|
||||
|
||||
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
|
||||
state_dict = _load_state_dict(hf_path)
|
||||
converted = _convert_hf_state_dict(state_dict)
|
||||
|
||||
config = _load_config(hf_path)
|
||||
dims = _build_dims(config, chunk_length)
|
||||
|
||||
_trim_positional_embedding(converted, dims["n_audio_ctx"])
|
||||
|
||||
package = {"dims": dims, "model_state_dict": converted}
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(package, output_path)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"hf_path",
|
||||
type=str,
|
||||
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="converted-whisper.pt",
|
||||
help="Destination path for the .pt file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Audio chunk length in seconds to support (default: 30)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
|
||||
output_path = Path(os.path.expanduser(args.output)).resolve()
|
||||
|
||||
convert_checkpoint(hf_path, output_path, args.chunk_length)
|
||||
print(f"Saved converted checkpoint to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,8 +1,7 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import logging
|
||||
from typing import List, Tuple, Optional, Union
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.warmup import load_file
|
||||
@@ -37,7 +36,7 @@ else:
|
||||
print(f"""{"="*50}\nFaster-Whisper not found but. Consider installing faster-whisper for better performance: `pip install faster-whisper`\n{"="*50}`""")
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
def model_path_and_type(model_path):
|
||||
def model_path_and_type(model_path: Union[str, Path]):
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
@@ -57,11 +56,28 @@ def model_path_and_type(model_path):
|
||||
elif file.suffix.lower() == '.safetensors':
|
||||
pytorch_path = file
|
||||
if pytorch_path is None:
|
||||
if (model_path / Path("pytorch_model.bin")).exists():
|
||||
pytorch_path = model_path / Path("pytorch_model.bin")
|
||||
if (path / Path("pytorch_model.bin")).exists():
|
||||
pytorch_path = path / Path("pytorch_model.bin")
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: str) -> Path:
|
||||
path = Path(model_path)
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc:
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
) from exc
|
||||
|
||||
downloaded_path = Path(snapshot_download(repo_id=model_path))
|
||||
return downloaded_path
|
||||
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
@@ -180,7 +196,9 @@ class SimulStreamingASR():
|
||||
self.fast_encoder = False
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
if self.model_path:
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
|
||||
resolved_model_path = resolve_model_path(self.model_path)
|
||||
self.model_path = str(resolved_model_path)
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
||||
self.model_name = self.pytorch_path.stem
|
||||
is_multilingual = not self.model_path.endswith(".en")
|
||||
elif self.model_size is not None:
|
||||
|
||||
Reference in New Issue
Block a user