diff --git a/README.md b/README.md index 00dede3..9fe1699 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ pip install whisperlivekit #### Quick Start 1. **Start the transcription server:** ```bash - whisperlivekit-server --model base --language en + wlk --model base --language en ``` 2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time! @@ -53,6 +53,7 @@ pip install whisperlivekit > - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages. > - For HTTPS requirements, see the **Parameters** section for SSL configuration options. +> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent. #### Use it to capture audio from web pages. @@ -85,10 +86,10 @@ See **Parameters & Configuration** below on how to use them. ```bash # Large model and translate from french to danish -whisperlivekit-server --model large-v3 --language fr --target-language da +wlk --model large-v3 --language fr --target-language da # Diarization and server listening on */80 -whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr +wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr ``` @@ -139,7 +140,7 @@ async def websocket_endpoint(websocket: WebSocket): | Parameter | Description | Default | |-----------|-------------|---------| | `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` | -| `--model-path` | .pt file/directory containing whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` | +| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. 
Recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` | | `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` | | `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` | | `--diarization` | Enable speaker identification | `False` | diff --git a/docs/models_compatible_formats.md b/docs/models_compatible_formats.md index e6a760b..2559129 100644 --- a/docs/models_compatible_formats.md +++ b/docs/models_compatible_formats.md @@ -13,4 +13,7 @@ May optionally contain: - **`.bin` file** - faster-whisper model for encoder (requires faster-whisper) - **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx) -To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder. \ No newline at end of file +## Hugging Face Repo ID +- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first. + +To improve speed/reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. 
If not, alignment heads are set to be all the heads of the last half layer of decoder. diff --git a/pyproject.toml b/pyproject.toml index 0ed41ce..e86f10c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "websockets", "torchaudio>=2.0.0", "torch>=2.0.0", + "huggingface-hub>=0.25.0", "tqdm", "tiktoken", 'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")' @@ -48,6 +49,7 @@ Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit" [project.scripts] whisperlivekit-server = "whisperlivekit.basic_server:main" +wlk = "whisperlivekit.basic_server:main" [tool.setuptools] packages = [ diff --git a/scripts/convert_hf_whisper.py b/scripts/convert_hf_whisper.py new file mode 100644 index 0000000..50352e8 --- /dev/null +++ b/scripts/convert_hf_whisper.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file. + +Optionally shrink the supported audio chunk length (in seconds) by trimming the +encoder positional embeddings and updating the stored model dimensions. 
+""" + +import argparse +import json +import os +from pathlib import Path +from typing import Dict, Tuple + +import torch + +from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE +from whisperlivekit.whisper.model import ModelDimensions +from whisperlivekit.whisper.utils import exact_div +from whisperlivekit.whisper import _convert_hf_state_dict + + +def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]: + safetensor_path = repo_path / "model.safetensors" + bin_path = repo_path / "pytorch_model.bin" + + if safetensor_path.is_file(): + try: + from safetensors.torch import load_file # type: ignore + except Exception as exc: # pragma: no cover - import guard + raise RuntimeError( + "Install safetensors to load model.safetensors " + "(pip install safetensors)" + ) from exc + return load_file(str(safetensor_path)) + + if bin_path.is_file(): + return torch.load(bin_path, map_location="cpu") + + raise FileNotFoundError( + f"Could not find model.safetensors or pytorch_model.bin under {repo_path}" + ) + + +def _load_config(repo_path: Path) -> Dict: + config_path = repo_path / "config.json" + if not config_path.is_file(): + raise FileNotFoundError( + f"Hugging Face checkpoint at {repo_path} is missing config.json" + ) + with open(config_path, "r", encoding="utf-8") as fp: + return json.load(fp) + + +def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]: + n_samples = int(round(chunk_length * SAMPLE_RATE)) + expected_samples = chunk_length * SAMPLE_RATE + if abs(n_samples - expected_samples) > 1e-6: + raise ValueError( + "chunk_length must align with sample rate so that " + "chunk_length * SAMPLE_RATE is an integer" + ) + n_frames = exact_div(n_samples, HOP_LENGTH) + n_audio_ctx = exact_div(n_frames, 2) + return n_frames, n_audio_ctx + + +def _build_dims(config: Dict, chunk_length: float) -> Dict: + base_dims = ModelDimensions( + n_mels=config["num_mel_bins"], + n_audio_ctx=config["max_source_positions"], + n_audio_state=config["d_model"], + 
n_audio_head=config["encoder_attention_heads"], + n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"], + n_vocab=config["vocab_size"], + n_text_ctx=config["max_target_positions"], + n_text_state=config["d_model"], + n_text_head=config["decoder_attention_heads"], + n_text_layer=config["decoder_layers"], + ).__dict__.copy() + + _, n_audio_ctx = _derive_audio_ctx(chunk_length) + base_dims["n_audio_ctx"] = n_audio_ctx + base_dims["chunk_length"] = chunk_length + return base_dims + + +def _trim_positional_embedding( + state_dict: Dict[str, torch.Tensor], target_ctx: int +) -> None: + key = "encoder.positional_embedding" + if key not in state_dict: + raise KeyError(f"{key} missing from converted state dict") + + tensor = state_dict[key] + if tensor.shape[0] < target_ctx: + raise ValueError( + f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}" + ) + if tensor.shape[0] == target_ctx: + return + state_dict[key] = tensor[:target_ctx].contiguous() + + +def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None: + state_dict = _load_state_dict(hf_path) + converted = _convert_hf_state_dict(state_dict) + + config = _load_config(hf_path) + dims = _build_dims(config, chunk_length) + + _trim_positional_embedding(converted, dims["n_audio_ctx"]) + + package = {"dims": dims, "model_state_dict": converted} + output_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(package, output_path) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format." + ) + parser.add_argument( + "hf_path", + type=str, + help="Path to the cloned Hugging Face repository (e.g. 
whisper-tiny.en)", + ) + parser.add_argument( + "--output", + type=str, + default="converted-whisper.pt", + help="Destination path for the .pt file", + ) + parser.add_argument( + "--chunk-length", + type=float, + default=30.0, + help="Audio chunk length in seconds to support (default: 30)", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + hf_path = Path(os.path.expanduser(args.hf_path)).resolve() + output_path = Path(os.path.expanduser(args.output)).resolve() + + convert_checkpoint(hf_path, output_path, args.chunk_length) + print(f"Saved converted checkpoint to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index dbfbd1f..8342c32 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -1,8 +1,7 @@ import sys import numpy as np import logging -from typing import List, Tuple, Optional -import logging +from typing import List, Tuple, Optional, Union import platform from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker from whisperlivekit.warmup import load_file @@ -37,7 +36,7 @@ else: print(f"""{"="*50}\nFaster-Whisper not found but. 
Consider installing faster-whisper for better performance: `pip install faster-whisper`\n{"="*50}`""") HAS_FASTER_WHISPER = False -def model_path_and_type(model_path): +def model_path_and_type(model_path: Union[str, Path]): path = Path(model_path) compatible_whisper_mlx = False @@ -57,11 +56,28 @@ def model_path_and_type(model_path): elif file.suffix.lower() == '.safetensors': pytorch_path = file if pytorch_path is None: - if (model_path / Path("pytorch_model.bin")).exists(): - pytorch_path = model_path / Path("pytorch_model.bin") + if (path / Path("pytorch_model.bin")).exists(): + pytorch_path = path / Path("pytorch_model.bin") return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper +def resolve_model_path(model_path: str) -> Path: + path = Path(model_path) + if path.exists(): + return path + + try: + from huggingface_hub import snapshot_download + except ImportError as exc: + raise FileNotFoundError( + f"Model path '{model_path}' does not exist locally and huggingface_hub " + "is not installed to download it." + ) from exc + + downloaded_path = Path(snapshot_download(repo_id=model_path)) + return downloaded_path + + class SimulStreamingOnlineProcessor: SAMPLING_RATE = 16000 @@ -180,7 +196,9 @@ class SimulStreamingASR(): self.fast_encoder = False self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True if self.model_path: - self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path) + resolved_model_path = resolve_model_path(self.model_path) + self.model_path = str(resolved_model_path) + self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path) self.model_name = self.pytorch_path.stem is_multilingual = not self.model_path.endswith(".en") elif self.model_size is not None: